Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions torchtitan/config/job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,9 @@ class Experimental:

enable_simplefsdp_passes: bool = False

enable_autoparallel_asynctp: bool = False


@dataclass
class Validation:
enable: bool = False
Expand Down
27 changes: 25 additions & 2 deletions torchtitan/experiments/auto_parallel/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from autoparallel.api import AutoParallel

from torch.distributed import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.placement_types import Replicate, Shard

Expand All @@ -33,6 +32,7 @@ def parallelize_llama(
the model must fit on GPU or CPU memory.
"""
world_mesh = parallel_dims.world_mesh

def input_fn():
global_batch_size = job_config.training.global_batch_size
if global_batch_size < 0:
Expand Down Expand Up @@ -62,6 +62,27 @@ def input_fn():
lambda bucket_idx: 1000 / parallel_dims.tp
)

if job_config.experimental.enable_autoparallel_asynctp:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

assert "tp" in world_mesh.mesh_dim_names
enable_symm_mem_for_group(world_mesh["tp"].get_group().group_name)
torch._inductor.config._micro_pipeline_tp = False
# Disable inductor AsyncTP passes, in favor of using Autoparallel passes fork.
from autoparallel.asynctp import micro_pipeline_tp_pass

existing_post_grad_custom_post_pass = (
torch._inductor.config.post_grad_custom_post_pass
)

def _pass(graph):
if existing_post_grad_custom_post_pass is not None:
existing_post_grad_custom_post_pass(graph)

micro_pipeline_tp_pass(graph, None)

torch._inductor.config.post_grad_custom_post_pass = _pass

# bail out
# model = model_fn()
# return model
Expand All @@ -78,6 +99,7 @@ def input_fn():
world_mesh,
mp_policy=mp_policy,
compile=job_config.compile,
repeated_subgraphs=True,
) as autop:
autop.add_parameter_memory_constraint(low=None, high=None)

Expand All @@ -101,7 +123,8 @@ def input_fn():
)
out_sharding = x_sharding
loss_parallel_enabled = (
parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel
parallel_dims.tp_enabled
and not job_config.parallelism.disable_loss_parallel
)
if loss_parallel_enabled:
out_sharding = tuple(
Expand Down
Loading