181 changes: 181 additions & 0 deletions test/test_distributed.py
@@ -0,0 +1,181 @@
#!/usr/bin/env pytest
import os
from unittest.mock import patch

import pytest
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP

import torchdynamo
from torchdynamo import config
from torchdynamo.testing import same


class ToyModel(nn.Module):
def __init__(self, in_feat=10, hidden_feat=5000, num_hidden=2, out_feat=5):
super().__init__()
self.net = nn.Sequential(
*[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] * num_hidden
            + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
)

def forward(self, inputs):
return self.net(inputs)


class CheckSplitsCompiler:
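    """Backend stub that counts invocations and returns the graph unchanged."""
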
def __init__(self):
self.compiler_called = 0

def compile_fn(self, gm, example_inputs):
self.compiler_called += 1
return gm


class TestDistributed(torchdynamo.testing.TestCase):
"""
    Test harness that initializes the dist process group.
"""

@classmethod
def setUpClass(cls):
super().setUpClass()
# _exit_stack is set up in TestCase
cls._exit_stack.enter_context(
patch.dict(
os.environ,
{
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12355",
},
)
)
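        # Single-process "distributed" setup: one rank, world_size 1, gloo
        # backend over localhost; no multi-process launcher is required.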
cls.rank = 0
cls.device = f"cuda:{cls.rank}"
dist.init_process_group("gloo", rank=cls.rank, world_size=1)

@classmethod
def tearDownClass(cls):
dist.destroy_process_group()
super().tearDownClass()

def get_model(self):
m = ToyModel().to(self.device)
inputs = torch.randn(20, 10).to(self.device)
outputs = m(inputs)
return m, inputs, outputs

# fails with assertion in aot_autograd
@pytest.mark.xfail
@patch.object(config, "optimize_ddp", False)
def test_ddp_baseline_aot_eager(self):
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=[self.rank])
ddp_m = torchdynamo.optimize("aot_eager")(ddp_m)
outputs = ddp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))

# segfaults
@pytest.mark.skip
@patch.object(config, "optimize_ddp", False)
def test_ddp_baseline_inductor(self):
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=[self.rank])
ddp_m = torchdynamo.optimize("inductor")(ddp_m)
outputs = ddp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))

# fails with assertion in aot_autograd
@pytest.mark.xfail
@patch.object(config, "optimize_ddp", False)
def test_fsdp_baseline_aot_eager(self):
m, inputs, correct_outputs = self.get_model()
fsdp_m = FSDP(m, device_id=self.rank)
fsdp_m = torchdynamo.optimize("aot_eager")(fsdp_m)
outputs = fsdp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))

# segfaults
@pytest.mark.skip
@patch.object(config, "optimize_ddp", False)
def test_fsdp_baseline_inductor(self):
m, inputs, correct_outputs = self.get_model()
fsdp_m = FSDP(m, device_id=self.rank)
fsdp_m = torchdynamo.optimize("inductor")(fsdp_m)
outputs = fsdp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))

@pytest.mark.skipif(
not hasattr(DDP, "_get_active_ddp_module"),
reason="requires pytorch landing in parallel",
)
@patch.object(config, "optimize_ddp", True)
def test_graph_split(self):
"""
        Ensures that the expected number of graph splits happen (based on
        bucket size and model parameters) by verifying how many times the
        user-provided compiler is called by the DDPOptimizer, which performs
        the graph splitting.
"""
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=[self.rank], bucket_cap_mb=25)

check_splits_compiler = CheckSplitsCompiler()

@torchdynamo.optimize(check_splits_compiler.compile_fn)
def opt_fn(inputs):
return ddp_m(inputs)

opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 3)

@pytest.mark.skipif(
not hasattr(DDP, "_get_active_ddp_module"),
reason="requires pytorch landing in parallel",
)
@patch.object(config, "optimize_ddp", True)
def test_no_split(self):
"""
        Ensures the DDPOptimizer returns a correct, compiled module without
        introducing graph splits (because the model's parameters fit within one bucket).
"""
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=[self.rank], bucket_cap_mb=250)

check_splits_compiler = CheckSplitsCompiler()

@torchdynamo.optimize(check_splits_compiler.compile_fn)
def opt_fn(inputs):
return ddp_m(inputs)

opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 1)

    # TODO: debug this; it regressed since initial development
@pytest.mark.skipif(
not hasattr(DDP, "_get_active_ddp_module"),
reason="requires pytorch landing in parallel",
)
@pytest.mark.xfail
@patch.object(config, "optimize_ddp", True)
def test_aot_autograd(self):
"""
        Explicitly check that the AotAutograd family of compilers works,
        since they require example inputs to be propagated between graph splits.
"""
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=[self.rank], bucket_cap_mb=25)

@torchdynamo.optimize("aot_nvfuser")
def opt_fn(inputs):
return ddp_m(inputs)

opt_outputs = opt_fn(inputs)
opt_outputs.sum().backward()
self.assertTrue(same(correct_outputs, opt_outputs))
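As an aside (an assumption, not stated in the diff): given the pytest shebang and the cuda:0 / gloo setup above, this file is presumably run single-process on a machine with at least one GPU, e.g.:

pytest test/test_distributed.py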
4 changes: 4 additions & 0 deletions torchdynamo/config.py
@@ -125,6 +125,10 @@
# false_fn produces code with identical guards.
enforce_cond_guards_match = True

# Automatically split model graph into pieces to match DDP bucket sizes
# to allow DDP comm/compute overlap
optimize_ddp = False


# If True, raises exception if TorchDynamo is called with a context manager
raise_on_ctx_manager_usage = False
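To make the new flag concrete, here is a minimal usage sketch assembled from the tests above; the flag name and the DDP/optimize calls mirror this diff, while the toy model, port, and backend choice are stand-ins:

import os
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

import torchdynamo
from torchdynamo import config

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12355")
dist.init_process_group("gloo", rank=0, world_size=1)

config.optimize_ddp = True  # opt in to DDPOptimizer graph splitting

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU()).to("cuda:0")
ddp_model = DDP(model, device_ids=[0], bucket_cap_mb=25)

@torchdynamo.optimize("aot_eager")  # any dynamo backend, as in the tests
def run(inputs):
    return ddp_model(inputs)

out = run(torch.randn(20, 10, device="cuda:0"))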
22 changes: 22 additions & 0 deletions torchdynamo/eval_frame.py
@@ -13,9 +13,11 @@
import torch
import torch.utils._pytree as pytree
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn.parallel.distributed import DistributedDataParallel

import torchdynamo
from torchdynamo.debug_utils import wrap_backend_debug
from torchdynamo.optimizations.distributed import DDPOptimizer
from torchdynamo.utils import checkpoint_params
from torchdynamo.utils import clone_inputs
from torchdynamo.utils import compile_times
@@ -231,6 +233,20 @@ def catch_errors(frame, cache_size):
):
                # namedtuple constructor
return None
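            # When optimize_ddp is enabled and a forward frame is being compiled
            # while inside an active DDP module, hand compilation to DDPOptimizer:
            # it splits the captured graph to line up with DDP's gradient buckets
            # (bucket_bytes_cap) so allreduce can overlap with compute, and has the
            # user's original backend (callback._torchdynamo_orig_callable) compile
            # each sub-graph.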
if config.optimize_ddp:
ddp_module = DistributedDataParallel._get_active_ddp_module()
if ddp_module and frame.f_code.co_name == "forward":
with compile_lock:
ddp_optimizer = DDPOptimizer(
bucket_bytes_cap=ddp_module.bucket_bytes_cap,
parameters_to_ignore=ddp_module.parameters_to_ignore,
backend_compile_fn=callback._torchdynamo_orig_callable,
)
hijacked_callback = convert_frame.convert_frame(
ddp_optimizer.compile_fn, guard_export_fn=None
)
return hijacked_callback(frame, cache_size)

with compile_lock:
return callback(frame, cache_size)
except Exception:
@@ -616,6 +632,12 @@ def patch():
if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
]

    # disable dynamo on the helper wrapper whose presence hints to dynamo that execution has entered DDP's forward
if hasattr(DistributedDataParallel, "_inside_ddp_forward"):
DistributedDataParallel._inside_ddp_forward = skip(
DistributedDataParallel._inside_ddp_forward
)

# disable profile hook
for opt in optimizers:
opt._cuda_graph_capture_health_check = disable(
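The DDPOptimizer itself (torchdynamo/optimizations/distributed.py) is not part of this diff. As a rough, hypothetical sketch of the idea it implements, and not the PR's actual code: walk the traced graph, accumulate parameter bytes until bucket_bytes_cap would be exceeded, start a new partition there, and let the backend compile each partition, which is what CheckSplitsCompiler counts in the tests above.

# Hypothetical illustration only; names and the traversal order are simplified.
import torch
import torch.fx as fx
from torch.fx.passes.split_module import split_module


def split_by_bucket_cap(gm: fx.GraphModule, bucket_bytes_cap: int) -> fx.GraphModule:
    bucket_id, bucket_bytes, node_to_bucket = 0, 0, {}
    for node in gm.graph.nodes:
        if node.op == "call_module":
            size = sum(
                p.numel() * p.element_size()
                for p in gm.get_submodule(node.target).parameters()
            )
            # Start a new partition once the running total would exceed the cap.
            if bucket_bytes and bucket_bytes + size > bucket_bytes_cap:
                bucket_id += 1
                bucket_bytes = 0
            bucket_bytes += size
        node_to_bucket[node] = bucket_id
    # Each partition becomes a submodule of the returned GraphModule; a real
    # optimizer would then invoke its backend_compile_fn once per partition,
    # which is what compiler_called counts in the tests.
    return split_module(gm, gm, lambda n: node_to_bucket[n])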