6 changes: 4 additions & 2 deletions torchtitan/experiments/simple_fsdp/backend.py
@@ -21,12 +21,14 @@ def get_compile_backend(backend_name: str) -> Union[str, callable]:
         # Perform auto optimization in aten fx-level and execute code in aot_eager backend
         # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
         from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
+
+        from torch._inductor.config import aten_distributed_optimizations as dist_opts
         from torch._inductor.fx_passes.overlap_scheduling import (
             schedule_overlap_bucketing,
         )

-        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True
-        torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False
+        dist_opts.collective_bucketing = True
+        dist_opts.insert_overlap_deps = False
         torch._inductor.config.allow_buffer_reuse = False

         def aten_autobucketing_reordering_pass(
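Side note on the backend.py change above: the autobucketing knobs move from the temporary torch._inductor.config.test_configs namespace to the new aten_distributed_optimizations config group. A minimal sketch of setting the new flags around torch.compile, under stated assumptions: the toy model and the plain "aot_eager" backend name are placeholders for illustration (torchtitan wires these flags into its own aot_autograd-based backend inside get_compile_backend), and the flags require a PyTorch build that ships aten_distributed_optimizations.

import torch
# Assumes a recent PyTorch build that provides this config group (see the diff above).
from torch._inductor.config import aten_distributed_optimizations as dist_opts

# Same knobs as backend.py: bucket collectives at the aten fx level, skip
# inserting explicit overlap dependencies, and disable buffer reuse.
dist_opts.collective_bucketing = True
dist_opts.insert_overlap_deps = False
torch._inductor.config.allow_buffer_reuse = False

# Toy single-process example; the model and backend name are assumptions,
# not part of this diff.
model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, backend="aot_eager")
out = compiled(torch.randn(4, 8))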
16 changes: 2 additions & 14 deletions torchtitan/experiments/simple_fsdp/simple_fsdp.py
@@ -19,7 +19,7 @@
     Replicate,
     Shard,
 )
-from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -95,19 +95,7 @@ def _distribute_dtensor(
"""
inner_spec = tensor._spec
outer_mesh, inner_mesh = device_mesh, inner_spec.mesh
outer_global_mesh = _mesh_resources.get_root_mesh(outer_mesh)
inner_global_mesh = _mesh_resources.get_root_mesh(inner_mesh)
if outer_global_mesh != inner_global_mesh or (
outer_global_mesh is None or inner_global_mesh is None
):
raise AssertionError(
"Cannot distribute tensor across two meshes without the same root mesh: \n"
f"outer global mesh: {outer_global_mesh}\ninner global mesh: {inner_global_mesh}"
)
assert outer_mesh.mesh_dim_names is not None
assert inner_mesh.mesh_dim_names is not None
submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
spanned_mesh = outer_global_mesh[submesh_names]
spanned_mesh = DeviceMesh._concatenate([outer_mesh, inner_mesh])

if len(dp_placements) == 1:
assert dp_placements[0].is_replicate() or dp_placements[0].is_shard()
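Side note on the simple_fsdp.py change above: instead of resolving a shared root mesh and re-slicing it by the combined dim names, the spanned mesh is now built with the private DeviceMesh._concatenate helper, which also drops the requirement that the outer and inner meshes share a root. A rough sketch of the two paths under assumed conditions: an already-initialized 8-rank job (e.g. via torchrun) and a hypothetical 2x4 dp/tp mesh; only the _concatenate call pattern is taken from the diff.

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

# Hypothetical 2x4 mesh; assumes torch.distributed is already initialized.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

# Old path (removed above): both sub-meshes had to resolve to the same root
# mesh, which was then re-sliced with the combined dim names.
spanned_old = mesh["dp", "tp"]

# New path: concatenate the sub-meshes directly; no common root mesh required.
spanned_new = DeviceMesh._concatenate([mesh["dp"], mesh["tp"]])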