[dynamo] fix dynamo + DTensor to work with 2d
Pair-debugged with @wconstab; we found issues on both the dynamo side and the TP FSDP extension side. This PR fixes the dynamo + DTensor integration so that the current graph-break FSDP approach can work with tensor parallel, by moving torch.compile after the FSDP wrapping.

ghstack-source-id: 1922a173c9b148a8a6bfe19e820bbebd531435dd
Pull Request resolved: #108329
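
In short, the 2-D setup now applies tensor parallelism and FSDP wrapping first and runs torch.compile on the already-wrapped module. A minimal sketch of that ordering, mirroring the updated test below; it assumes a distributed run with an initialized process group, an even world size, and illustrative names (model, inp) that are not defined here:

    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import DeviceMesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module

    # assumes init_process_group(...) has already been called and world size is divisible by 2
    world_size = dist.get_world_size()
    twod_mesh = DeviceMesh("cuda", torch.arange(world_size).view(2, -1))  # [dp, tp]
    fsdp_pg = twod_mesh.get_dim_groups()[0]  # data-parallel dimension of the 2-D mesh

    # tensor-parallel first, then FSDP around the TP-parallelized module
    tp_model = parallelize_module(model, twod_mesh, PairwiseParallel(), tp_mesh_dim=1)
    fsdp_2d = FSDP(
        tp_model,
        process_group=fsdp_pg,
        device_id=dist.get_rank(),
        use_orig_params=True,
    )

    # compile last, after all wrapping, so dynamo traces the already-wrapped module
    compiled_2d = torch.compile(fsdp_2d, backend="eager")
    out = compiled_2d(inp)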
wanchaol committed Aug 31, 2023
1 parent e68b3ad commit 054abae
Showing 6 changed files with 18 additions and 7 deletions.
7 changes: 3 additions & 4 deletions test/distributed/_tensor/test_dtensor_compile.py
@@ -165,18 +165,17 @@ def test_2d_fsdp_tp_compile(self):
tp_model2 = parallelize_module(
model_copy, twod_mesh, PairwiseParallel(), tp_mesh_dim=1
)
compiled_tp = torch.compile(tp_model2, backend="eager", fullgraph=True)

# TODO: now we first apply torch compile on tp model then use fsdp to wrap it, ideally
# we should apply torch.compile after fsdp wrap, but the current graph break approach
# have some issues with the tensor subclass compilation, need to dig into this later
compiled_2d = FSDP(
compiled_tp,
fsdp_2d = FSDP(
tp_model2,
process_group=fsdp_pg,
device_id=self.rank,
use_orig_params=True,
)

compiled_2d = torch.compile(fsdp_2d, backend="eager")
compiled_output = compiled_2d(inp)

self.assertEqual(out, compiled_output)
1 change: 1 addition & 0 deletions torch/_dynamo/skipfiles.py
@@ -169,6 +169,7 @@ def _module_dir(m: types.ModuleType):
FILENAME_ALLOWLIST |= {
_module_dir(torch) + "distributed/tensor/parallel/_utils.py",
_module_dir(torch) + "distributed/tensor/parallel/style.py",
_module_dir(torch) + "distributed/tensor/parallel/_data_parallel_utils.py",
_module_dir(torch) + "distributed/_tensor/api.py",
_module_dir(torch) + "distributed/_tensor/device_mesh.py",
}
9 changes: 9 additions & 0 deletions torch/_dynamo/variables/builder.py
@@ -86,6 +86,7 @@
from .distributed import (
DeviceMeshVariable,
PlacementClassVariable,
PlacementVariable,
ProcessGroupVariable,
)
from .functions import (
@@ -665,6 +666,14 @@ def index_source(key):
source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
)
elif PlacementVariable.is_placement(value):
# TODO: see if we need to add custom guard instead
# of a simple ID_MATCH
return PlacementVariable(
value,
source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
)
elif issubclass(type(value), type):
# TODO(whc) the following seems preferable but breaks some tests, debug
# elif inspect.isclass(value):
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/distributed.py
@@ -92,7 +92,7 @@ def is_placement(value):

from torch.distributed._tensor.placement_types import Placement

return istype(value, Placement)
return isinstance(value, Placement)

def as_python_constant(self):
return self.value
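
The istype -> isinstance switch matters because concrete placements such as Shard(0) or Replicate() are instances of Placement subclasses, never of the Placement base class itself, so an exact-type check never recognizes them. A small standalone illustration (no process group needed):

    from torch.distributed._tensor.placement_types import Placement, Replicate, Shard

    p = Shard(0)
    print(type(p) is Placement)                 # False: what an exact-type check like istype sees
    print(isinstance(p, Placement))             # True: matches Shard, Replicate, _Partial, ...
    print(isinstance(Replicate(), Placement))   # True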
1 change: 1 addition & 0 deletions torch/_dynamo/variables/torch.py
@@ -572,6 +572,7 @@ def get_state_from_generator():
elif is_from_local(self.value):
# rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
# and rewrite args to have only proxyable args, then insert call_function
# TODO: support cases where device_mesh + placements specified as kwargs
args_as_value = [x.as_python_constant() for x in args[1:]]

def fn_with_prim_types(x, **kwargs):
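
Per the TODO above, the rewrite currently reads device_mesh and placements from the positional arguments of DTensor.from_local; the keyword spelling is not handled yet. A rough sketch of the positional form, assuming an initialized process group and a 1-D mesh (illustrative only, not guaranteed to capture a full graph in every configuration):

    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import DeviceMesh, DTensor, Shard

    # assumes init_process_group(...) has already been called
    mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

    def wrap(x):
        # device_mesh and placements passed positionally -> picked up by the rewrite above
        return DTensor.from_local(x, mesh, [Shard(0)], run_check=False)

    compiled_wrap = torch.compile(wrap, backend="eager")
    # the unsupported kwargs spelling the TODO refers to would be:
    #   DTensor.from_local(x, device_mesh=mesh, placements=[Shard(0)])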
5 changes: 3 additions & 2 deletions torch/distributed/tensor/parallel/_data_parallel_utils.py
@@ -16,6 +16,7 @@
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
from torch.distributed._tensor import DTensor as DistributedTensor, Shard as DShard

from torch.distributed._tensor.placement_types import DTensorSpec

from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
@@ -151,8 +152,8 @@ def _flatten_tensor(
def _unflatten_tensor(tensor: torch.Tensor, spec: DTensorSpec) -> torch.Tensor:
result = DistributedTensor.from_local(
tensor,
device_mesh=spec.mesh,
placements=spec.placements,
spec.mesh,
spec.placements,
run_check=False,
)

