
[ddp] must set static_graph=False when running with dynamo #93672

@davidberard98

Description

The static_graph documentation from the PyTorch docs:

When set to True, DDP knows the trained graph is static. Static graph means:

1) The set of used and unused parameters will not change during the whole training loop; in this case, it does not matter whether users set find_unused_parameters = True or not.
2) How the graph is trained will not change during the whole training loop (meaning there is no control flow depending on iterations).

When static_graph is set to True, DDP will support cases that could not be supported in the past:

1) Reentrant backwards.
2) Activation checkpointing multiple times.
3) Activation checkpointing when the model has unused parameters.
4) Model parameters that are outside of the forward function.
5) Potentially improved performance when there are unused parameters, as DDP will not search the graph in each iteration to detect unused parameters.

To check whether you can set static_graph to True, one way is to check the DDP logging data at the end of your previous model training: if ddp_logging_data.get("can_set_static_graph") == True, you can mostly set static_graph = True as well.
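For reference, here is a minimal sketch of that check. It assumes a process group is already initialized (e.g. launched via torchrun) and uses DDP's private _get_ddp_logging_data() helper, which returns the logging fields as a dict:

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Sketch only: assumes torch.distributed.init_process_group() has
# already run (e.g. the script was launched via torchrun).
model = nn.Linear(8, 8)
ddp_model = DDP(model)

# Run at least one full iteration; the docs suggest checking the flag
# at the end of training, once DDP has observed the graph.
out = ddp_model(torch.randn(4, 8))
out.sum().backward()

# _get_ddp_logging_data() is a private API returning DDP's logging
# fields, including the "can_set_static_graph" flag quoted above.
ddp_logging_data = ddp_model._get_ddp_logging_data()
if ddp_logging_data.get("can_set_static_graph"):
    print("It should be safe to construct DDP with static_graph=True")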

When training resnet50 with eager + DDP on torchbench, we can set static_graph=True.
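That eager configuration looks roughly like this (a sketch, not the actual torchbench harness; assumes an initialized process group and one GPU per rank):

import torch
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDP

model = torchvision.models.resnet50().cuda()
ddp_model = DDP(
    model,
    device_ids=[torch.cuda.current_device()],
    static_graph=True,  # works fine in eager mode
)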

But when training with torchdynamo + inductor + DDP, we need to set static_graph=False, otherwise we get this error:

Traceback (most recent call last):
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/submitit/core/_submit.py", line 11, in <module>
    submitit_main()
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/submitit/core/submission.py", line 72, in submitit_main
    process_job(args.folder)
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/submitit/core/submission.py", line 65, in process_job
    raise error
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/submitit/core/submission.py", line 54, in process_job
    result = delayed.result()
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/submitit/core/utils.py", line 133, in result
    self._result = self.function(*self.args, **self.kwargs)
  File "ddp_experiments.py", line 151, in __call__
    return trainer_class(self.args, model_class, model_args=self.model_args).measure()
  File "/fsx/users/dberard/scratch-local/bench-fast/benchmark/torchbenchmark/util/distributed/core_model/trainer.py", line 79, in measure
    self.benchmark.invoke()
  File "/fsx/users/dberard/scratch-local/bench-fast/benchmark/torchbenchmark/util/model.py", line 190, in invoke
    self.train()
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/torchdynamo/eval_frame.py", line 166, in _fn
    return fn(*args, **kwargs)
  File "/fsx/users/dberard/scratch-local/bench-fast/benchmark/torchbenchmark/util/framework/vision/model_factory.py", line 75, in train
    def train(self):
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/functorch/_src/monkey_patching.py", line 77, in _backward
    return _old_backward(*args, **kwargs)
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/torch/_tensor.py", line 484, in backward
    torch.autograd.backward(
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/torch/autograd/__init__.py", line 191, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Your training graph has changed in this iteration, e.g., one parameter is unused in first iteration, but then got used in the second iteration. this is not compatible with static_graph set to True.
Exception raised from autograd_hook at /scratch/dberard/bench-fast/pytorch/torch/csrc/distributed/c10d/reducer.cpp:668 (most recent call first):
frame #0: <unknown function> + 0x104595 (0x7fc383030595 in /data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/torch/lib/libc10.so)
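In other words, the workaround is to construct DDP with static_graph=False whenever the module will be compiled. A minimal sketch using the torch.compile entry point (the run above used the standalone torchdynamo package, but the constraint is the same); assumes an initialized process group and one GPU per rank:

import torch
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDP

model = torchvision.models.resnet50().cuda()
ddp_model = DDP(
    model,
    device_ids=[torch.cuda.current_device()],
    static_graph=False,  # required once dynamo/inductor is involved
)
compiled = torch.compile(ddp_model, backend="inductor")

x = torch.randn(8, 3, 224, 224, device="cuda")
# With static_graph=True, this backward raises the RuntimeError shown above.
compiled(x).sum().backward()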

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng @kiukchung @d4l3k @LucasLLC @soumith @ngimel


Labels

module: ddp (Issues/PRs related to distributed data parallel training)
module: dynamo
oncall: distributed (add this issue/PR to the distributed oncall triage queue)
oncall: pt2
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
