diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 8571e5680c..94dddaa9d0 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -739,6 +739,9 @@ class Experimental:
 
     enable_simplefsdp_passes: bool = False
 
+    enable_autoparallel_asynctp: bool = False
+
+
 @dataclass
 class Validation:
     enable: bool = False
diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py
index 6648f29ab8..17d53a56d1 100644
--- a/torchtitan/experiments/auto_parallel/parallelize_llama.py
+++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py
@@ -10,7 +10,6 @@
 
 from autoparallel.api import AutoParallel
 
-from torch.distributed import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor.placement_types import Replicate, Shard
 
@@ -33,6 +32,7 @@ def parallelize_llama(
     the model must fit on GPU or CPU memory.
     """
     world_mesh = parallel_dims.world_mesh
+
     def input_fn():
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:
@@ -62,6 +62,27 @@ def input_fn():
             lambda bucket_idx: 1000 / parallel_dims.tp
         )
 
+    if job_config.experimental.enable_autoparallel_asynctp:
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+        assert "tp" in world_mesh.mesh_dim_names
+        enable_symm_mem_for_group(world_mesh["tp"].get_group().group_name)
+        torch._inductor.config._micro_pipeline_tp = False
+        # Disable inductor AsyncTP passes, in favor of using Autoparallel passes fork.
+        from autoparallel.asynctp import micro_pipeline_tp_pass
+
+        existing_post_grad_custom_post_pass = (
+            torch._inductor.config.post_grad_custom_post_pass
+        )
+
+        def _pass(graph):
+            if existing_post_grad_custom_post_pass is not None:
+                existing_post_grad_custom_post_pass(graph)
+
+            micro_pipeline_tp_pass(graph, None)
+
+        torch._inductor.config.post_grad_custom_post_pass = _pass
+
     # bail out
     # model = model_fn()
     # return model
@@ -78,6 +99,7 @@ def input_fn():
         world_mesh,
         mp_policy=mp_policy,
         compile=job_config.compile,
+        repeated_subgraphs=True,
     ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)
 
@@ -101,7 +123,8 @@ def input_fn():
         )
         out_sharding = x_sharding
         loss_parallel_enabled = (
-            parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel
+            parallel_dims.tp_enabled
+            and not job_config.parallelism.disable_loss_parallel
         )
         if loss_parallel_enabled:
             out_sharding = tuple(