diff --git a/test/test_distributed.py b/test/test_distributed.py
new file mode 100644
index 0000000000..1f7ed1de88
--- /dev/null
+++ b/test/test_distributed.py
@@ -0,0 +1,181 @@
+#!/usr/bin/env pytest
+import os
+from unittest.mock import patch
+
+import pytest
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import torchdynamo
+from torchdynamo import config
+from torchdynamo.testing import same
+
+
+class ToyModel(nn.Module):
+    def __init__(self, in_feat=10, hidden_feat=5000, num_hidden=2, out_feat=5):
+        super().__init__()
+        self.net = nn.Sequential(
+            *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
+            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] * num_hidden
+            + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
+        )
+
+    def forward(self, inputs):
+        return self.net(inputs)
+
+
+class CheckSplitsCompiler:
+    def __init__(self):
+        self.compiler_called = 0
+
+    def compile_fn(self, gm, example_inputs):
+        self.compiler_called += 1
+        return gm
+
+
+class TestDistributed(torchdynamo.testing.TestCase):
+    """
+    Test harness that initializes the dist process group.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        # _exit_stack is set up in TestCase
+        cls._exit_stack.enter_context(
+            patch.dict(
+                os.environ,
+                {
+                    "MASTER_ADDR": "localhost",
+                    "MASTER_PORT": "12355",
+                },
+            )
+        )
+        cls.rank = 0
+        cls.device = f"cuda:{cls.rank}"
+        dist.init_process_group("gloo", rank=cls.rank, world_size=1)
+
+    @classmethod
+    def tearDownClass(cls):
+        dist.destroy_process_group()
+        super().tearDownClass()
+
+    def get_model(self):
+        m = ToyModel().to(self.device)
+        inputs = torch.randn(20, 10).to(self.device)
+        outputs = m(inputs)
+        return m, inputs, outputs
+
+    # fails with assertion in aot_autograd
+    @pytest.mark.xfail
+    @patch.object(config, "optimize_ddp", False)
+    def test_ddp_baseline_aot_eager(self):
+        m, inputs, correct_outputs = self.get_model()
+        ddp_m = DDP(m, device_ids=[self.rank])
+        ddp_m = torchdynamo.optimize("aot_eager")(ddp_m)
+        outputs = ddp_m(inputs)
+        self.assertTrue(same(correct_outputs, outputs))
+
+    # segfaults
+    @pytest.mark.skip
+    @patch.object(config, "optimize_ddp", False)
+    def test_ddp_baseline_inductor(self):
+        m, inputs, correct_outputs = self.get_model()
+        ddp_m = DDP(m, device_ids=[self.rank])
+        ddp_m = torchdynamo.optimize("inductor")(ddp_m)
+        outputs = ddp_m(inputs)
+        self.assertTrue(same(correct_outputs, outputs))
+
+    # fails with assertion in aot_autograd
+    @pytest.mark.xfail
+    @patch.object(config, "optimize_ddp", False)
+    def test_fsdp_baseline_aot_eager(self):
+        m, inputs, correct_outputs = self.get_model()
+        fsdp_m = FSDP(m, device_id=self.rank)
+        fsdp_m = torchdynamo.optimize("aot_eager")(fsdp_m)
+        outputs = fsdp_m(inputs)
+        self.assertTrue(same(correct_outputs, outputs))
+
+    # segfaults
+    @pytest.mark.skip
+    @patch.object(config, "optimize_ddp", False)
+    def test_fsdp_baseline_inductor(self):
+        m, inputs, correct_outputs = self.get_model()
+        fsdp_m = FSDP(m, device_id=self.rank)
+        fsdp_m = torchdynamo.optimize("inductor")(fsdp_m)
+        outputs = fsdp_m(inputs)
+        self.assertTrue(same(correct_outputs, outputs))
+
+    @pytest.mark.skipif(
+        not hasattr(DDP, "_get_active_ddp_module"),
+        reason="requires pytorch landing in parallel",
+    )
+    @patch.object(config, "optimize_ddp", True)
+    def test_graph_split(self):
+        """
+        Ensures that the appropriate number of graph splits occurs (based on
+        bucket size and model parameters) by verifying the number of
+        times the user-provided compiler is called by the DDPOptimizer,
+        which performs the graph splitting.
+        """
+        m, inputs, correct_outputs = self.get_model()
+        ddp_m = DDP(m, device_ids=[self.rank], bucket_cap_mb=25)
+
+        check_splits_compiler = CheckSplitsCompiler()
+
+        @torchdynamo.optimize(check_splits_compiler.compile_fn)
+        def opt_fn(inputs):
+            return ddp_m(inputs)
+
+        opt_outputs = opt_fn(inputs)
+        self.assertTrue(same(correct_outputs, opt_outputs))
+        self.assertEqual(check_splits_compiler.compiler_called, 3)
+
+    @pytest.mark.skipif(
+        not hasattr(DDP, "_get_active_ddp_module"),
+        reason="requires pytorch landing in parallel",
+    )
+    @patch.object(config, "optimize_ddp", True)
+    def test_no_split(self):
+        """
+        Ensures the DDPOptimizer returns a correct, compiled module without
+        introducing graph splits (based on the model parameters fitting in the bucket).
+        """
+        m, inputs, correct_outputs = self.get_model()
+        ddp_m = DDP(m, device_ids=[self.rank], bucket_cap_mb=250)
+
+        check_splits_compiler = CheckSplitsCompiler()
+
+        @torchdynamo.optimize(check_splits_compiler.compile_fn)
+        def opt_fn(inputs):
+            return ddp_m(inputs)
+
+        opt_outputs = opt_fn(inputs)
+        self.assertTrue(same(correct_outputs, opt_outputs))
+        self.assertEqual(check_splits_compiler.compiler_called, 1)
+
+    # TODO: debug this; it regressed since initial development
+    @pytest.mark.skipif(
+        not hasattr(DDP, "_get_active_ddp_module"),
+        reason="requires pytorch landing in parallel",
+    )
+    @pytest.mark.xfail
+    @patch.object(config, "optimize_ddp", True)
+    def test_aot_autograd(self):
+        """
+        Explicitly checks that the AotAutograd family of compilers works,
+        since they require example inputs to be propagated between graph splits.
+        """
+        m, inputs, correct_outputs = self.get_model()
+        ddp_m = DDP(m, device_ids=[self.rank], bucket_cap_mb=25)
+
+        @torchdynamo.optimize("aot_nvfuser")
+        def opt_fn(inputs):
+            return ddp_m(inputs)
+
+        opt_outputs = opt_fn(inputs)
+        opt_outputs.sum().backward()
+        self.assertTrue(same(correct_outputs, opt_outputs))
diff --git a/torchdynamo/config.py b/torchdynamo/config.py
index 00abba1478..71f7ac7447 100644
--- a/torchdynamo/config.py
+++ b/torchdynamo/config.py
@@ -125,6 +125,10 @@
 # false_fn produces code with identical guards.
 enforce_cond_guards_match = True
 
+# Automatically split model graph into pieces to match DDP bucket sizes
+# to allow DDP comm/compute overlap
+optimize_ddp = False
+
 # If True, raises exception if TorchDynamo is called with a context manager
 raise_on_ctx_manager_usage = False
 
diff --git a/torchdynamo/eval_frame.py b/torchdynamo/eval_frame.py
index a9a518d110..d6760704ff 100644
--- a/torchdynamo/eval_frame.py
+++ b/torchdynamo/eval_frame.py
@@ -13,9 +13,11 @@
 import torch
 import torch.utils._pytree as pytree
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.nn.parallel.distributed import DistributedDataParallel
 
 import torchdynamo
 from torchdynamo.debug_utils import wrap_backend_debug
+from torchdynamo.optimizations.distributed import DDPOptimizer
 from torchdynamo.utils import checkpoint_params
 from torchdynamo.utils import clone_inputs
 from torchdynamo.utils import compile_times
@@ -231,6 +233,20 @@ def catch_errors(frame, cache_size):
             ):
                 # nametuple constructor
                 return None
+            if config.optimize_ddp:
+                ddp_module = DistributedDataParallel._get_active_ddp_module()
+                if ddp_module and frame.f_code.co_name == "forward":
+                    with compile_lock:
+                        ddp_optimizer = DDPOptimizer(
+                            bucket_bytes_cap=ddp_module.bucket_bytes_cap,
+                            parameters_to_ignore=ddp_module.parameters_to_ignore,
+                            backend_compile_fn=callback._torchdynamo_orig_callable,
+                        )
+                        hijacked_callback = convert_frame.convert_frame(
+                            ddp_optimizer.compile_fn, guard_export_fn=None
+                        )
+                        return hijacked_callback(frame, cache_size)
+
             with compile_lock:
                 return callback(frame, cache_size)
         except Exception:
@@ -616,6 +632,12 @@ def patch():
         if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
     ]
 
+    # disable dynamo for the wrapper that helps give dynamo hints about entering DDP
+    if hasattr(DistributedDataParallel, "_inside_ddp_forward"):
+        DistributedDataParallel._inside_ddp_forward = skip(
+            DistributedDataParallel._inside_ddp_forward
+        )
+
     # disable profile hook
     for opt in optimizers:
         opt._cuda_graph_capture_health_check = disable(
diff --git a/torchdynamo/optimizations/distributed.py b/torchdynamo/optimizations/distributed.py
new file mode 100644
index 0000000000..7598c28001
--- /dev/null
+++ b/torchdynamo/optimizations/distributed.py
@@ -0,0 +1,181 @@
+from typing import Any
+from typing import List
+
+import torch
+import torch.fx.traceback as fx_traceback
+from torch import fx
+from torch.fx.node import Node
+
+
+def args_str(args):
+    # a debug helper
+    if torch.is_tensor(args):
+        return f"T[{args.shape}]"
+    elif isinstance(args, tuple):
+        return f"tuple({', '.join([args_str(x) for x in args])})"
+    elif isinstance(args, list):
+        return f"list({', '.join([args_str(x) for x in args])})"
+    else:
+        return str(args)
+
+
+class DDPOptimizer:
+    def __init__(
+        self,
+        bucket_bytes_cap: int,
+        parameters_to_ignore: List[str],
+        backend_compile_fn,
+        debug=False,
+    ):
+        self.bucket_bytes_cap = bucket_bytes_cap
+        self.parameters_to_ignore = parameters_to_ignore
+        self.backend_compile_fn = backend_compile_fn
+        self.debug = debug
+
+    def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
+        """
+        TODO:
+        - handle params_and_buffers_to_ignore
+        - handle kwargs
+        """
+
+        # 1: compute the partition map according to DDP bucket logic
+        bucket_bytes = 0
+        bucket_actual_sizes = []
+        node_splits = [[]]
+        for node in reversed(gm.graph.nodes):
+            if bucket_bytes >= self.bucket_bytes_cap:
+                bucket_actual_sizes.insert(0, bucket_bytes)
+                bucket_bytes = 0
+                node_splits.insert(0, [])
+
+            if node.op == "output" or node.op == "placeholder":
node.op == "placeholder": + continue + + elif node.op == "call_module": + target = gm.get_submodule(node.target) + params_size_b = sum( + [ + p.storage().nbytes() + for p in target.parameters() + if p.requires_grad + ] + ) + bucket_bytes += params_size_b + # print(f"accumulated {params_size_b} b from {node}") + else: + # TODO(whc) confirm this: + # (e.g. call_method, call_function aren't expected to 'have' parameters) + pass + + node_splits[0].append(node) + + if len(node_splits) == 1: + if self.debug: + print( + "DDPOptimizer did not split graphs." + f" Accumulated {bucket_bytes} bytes, and bucket cap is {self.bucket_bytes_cap}" + ) + return self.backend_compile_fn(gm, example_inputs) + + if len(bucket_actual_sizes) < len(node_splits): + bucket_actual_sizes.insert(0, bucket_bytes) + + if self.debug: + print( + f"DDPOptimizer used bucket cap {self.bucket_bytes_cap}" + f" and split graphs into parameter sizes {', '.join([str(b) for b in bucket_actual_sizes])}" + ) + + # 2: partition the graphmodule according to bucket capacity + partition_map = {} + for p, nodes in enumerate(node_splits): + for node in nodes: + partition_map[node] = p + + split_gm = fx.passes.split_module.split_module( + gm, None, lambda node: partition_map[node] + ) + if self.debug: + with open("debug_ddp_optimizer.log", "w") as dump_file: + dump_file.write("---orig graph---") + dump_file.write(str(gm.graph)) + dump_file.write("\n---split graph---") + dump_file.write(str(split_gm.graph)) + + # 3: compile each of the partitioned submodules using the user-provided compiler + class SubmodCompiler(torch.fx.interpreter.Interpreter): + def __init__(self, module, compiler, debug=False): + super().__init__(module) + self.compiler = compiler + self.debug = debug + + def compile_submod(self, submod, args, kwargs): + """ + Compile the submodule, + using a wrapper to make sure its output is always a tuple, + which is required by AotAutograd based compilers + """ + assert len(kwargs) == 0, "We assume only args for these modules" + + class WrapperModule(torch.nn.Module): + def __init__(self, compiled_submod, unwrap_singleton_tuple): + super().__init__() + self.compiled_submod = compiled_submod + self.unwrap_singleton_tuple = unwrap_singleton_tuple + + def forward(self, *args): + x = self.compiled_submod(*args) + # TODO(whc) + # for some reason the isinstance check is necessary if I split one node per submod + # - even though I supposedly wrapped the output in a tuple in those cases, the real + # compiled module was still returning a tensor + if self.unwrap_singleton_tuple and isinstance(x, tuple): + return x[0] + return x + + unwrap_singleton_tuple = False + for sn in submod.graph.nodes: + if sn.op == "output": + if not isinstance(sn.args[0], tuple): + unwrap_singleton_tuple = True + sn.args = (sn.args,) + wrapper = WrapperModule( + self.compiler(submod, args), + unwrap_singleton_tuple, + ) + return wrapper + + def run_node(self, n: Node) -> Any: + with fx_traceback.append_stack_trace(n.stack_trace): + args, kwargs = self.fetch_args_kwargs_from_env(n) + if self.debug: + print(f"run_node {n.op}, {n.target} got args {args_str(args)}") + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + # modify the currently running FX graph + # maybe this isn't sound in general, but only changing the target of a node might be ok? 
+                    if n.op == "call_module":
+                        submod = self.fetch_attr(n.target)
+                        if self.debug:
+                            with open("debug_ddp_optimizer.log", "a") as dump_file:
+                                dump_file.write(f"\n---{n.target} graph---")
+                                dump_file.write(str(submod.graph))
+                        compiled_submod = self.compile_submod(submod, args, kwargs)
+                        n.target = "compiled_" + n.target
+                        self.module.delete_submodule(n.target)
+                        self.module.add_submodule(n.target, compiled_submod)
+
+                    # then we execute the modified node using the usual logic
+                    return getattr(self, n.op)(n.target, args, kwargs)
+
+        submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, self.debug)
+        submod_compiler.run(*example_inputs)
+
+        if self.debug:
+            with open("debug_ddp_optimizer.log", "a") as dump_file:
+                dump_file.write("\n---final graph---")
+                dump_file.write(str(split_gm.graph))
+
+        return split_gm
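Usage sketch (not part of the patch itself, mirroring test_graph_split above): with config.optimize_ddp enabled, calling a DDP-wrapped module under torchdynamo.optimize routes the forward frame through DDPOptimizer, which splits the graph at bucket boundaries and invokes the user-provided compiler once per split. The snippet below assumes a single-process "gloo" process group as in the test harness and a PyTorch build that exposes DistributedDataParallel._get_active_ddp_module; my_compiler is a hypothetical stand-in for any backend compile function.

```python
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

import torchdynamo
from torchdynamo import config


def my_compiler(gm, example_inputs):
    # With optimize_ddp enabled, DDPOptimizer calls this once per split submodule.
    return gm


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # Enable the new config flag so dynamo hands DDP forward frames to DDPOptimizer.
    config.optimize_ddp = True

    # Roughly the ToyModel shape from the test, large enough to exceed a 25MB bucket.
    model = nn.Sequential(
        nn.Linear(10, 5000), nn.ReLU(),
        nn.Linear(5000, 5000), nn.ReLU(),
        nn.Linear(5000, 5000), nn.ReLU(),
        nn.Linear(5000, 5), nn.ReLU(),
    )
    ddp_model = DDP(model, bucket_cap_mb=25)

    @torchdynamo.optimize(my_compiler)
    def opt_fn(inputs):
        return ddp_model(inputs)

    out = opt_fn(torch.randn(20, 10))
    print(out.shape)

    dist.destroy_process_group()
```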