diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 6861a7f1a8..2ec8b84e03 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -857,37 +857,11 @@ class Experimental:
     needs to ensure that the path can be imported.
     """
 
-    # "none", "all", "only_fsdp"
-    bucket_all_gathers_fx: str = "none"
-
-    # "none", "all"
-    bucket_reduce_scatters_fx: str = "none"
-
-    reorder_for_compute_comm_overlap: bool = False
-    """
-    Whether to enable inductor comm reordering passes
-    """
-
-    reorder_for_compute_comm_overlap_passes: list[str] = field(
-        default_factory=lambda: [
-            "sink_waits_iterative",
-            "reorder_communication_preserving_peak_memory",
-        ]
-    )
-    """
-    Sequence of reordering passes (names of functions inside _inductor.comms) to call,
-    if reorder_for_compute_comm_overlap is enabled.
-    """
-
-    reorder_prefetch_limit: int | None = None
-    """
-    How many ops to allow moving any individual collective, if 'reorder_communication_preserving_peak_memory'
-    pass is enabled. default of None means unlimited
-    """
+    # "aten" (default), "inductor", "none"
+    comms_bucket_reorder_strategy: str = "aten"
 
     autop_force_bf16: bool = False
-    enable_simplefsdp_passes: bool = False
 
 
 @dataclass
 class Validation:
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 7d492f7249..61ffe4ec1a 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -33,6 +33,7 @@
     maybe_enable_memory_snapshot,
     maybe_enable_profiling,
 )
+from autoparallel.auto_bucketing import configure_inductor_for_autobucketing
 
 
 class Trainer(torch.distributed.checkpoint.stateful.Stateful):
@@ -122,46 +123,7 @@ def __init__(self, job_config: JobConfig):
             torch._inductor.config.allow_buffer_reuse = False
 
         # allow configuring inductor comms optimizations from torchtitan commandline
-        if job_config.experimental.enable_simplefsdp_passes:
-            # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282)
-            from autoparallel.auto_bucketing import (
-                simple_fsdp_autobucketing_reordering_pass,
-                simplefsdp_autobucketing_config,
-            )
-
-            torch._inductor.config.allow_buffer_reuse = False
-            torch._inductor.config.reorder_for_peak_memory = False
-            torch._inductor.config.reorder_for_compute_comm_overlap = True
-            simplefsdp_autobucketing_config.save_estimation_path = (
-                "/tmp/torchtitan_simplefsdp_comm_estimation.pkl"
-            )
-            simple_fsdp_autobucketing_reordering_pass = partial(
-                simple_fsdp_autobucketing_reordering_pass,
-                configs=simplefsdp_autobucketing_config,
-            )
-            torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
-                simple_fsdp_autobucketing_reordering_pass
-            ]
-
-            # Don't use both sets of passes at the same time!
-            torch._inductor.config.bucket_all_gathers_fx = "none"
-            torch._inductor.config.bucket_reduce_scatters_fx = "none"
-        else:
-            torch._inductor.config.bucket_all_gathers_fx = (
-                job_config.experimental.bucket_all_gathers_fx
-            )
-            torch._inductor.config.bucket_reduce_scatters_fx = (
-                job_config.experimental.bucket_reduce_scatters_fx
-            )
-            torch._inductor.config.reorder_for_compute_comm_overlap = (
-                job_config.experimental.reorder_for_compute_comm_overlap
-            )
-            torch._inductor.config.reorder_for_compute_comm_overlap_passes = (
-                job_config.experimental.reorder_for_compute_comm_overlap_passes
-            )
-            torch._inductor.config.reorder_prefetch_limit = (
-                job_config.experimental.reorder_prefetch_limit
-            )
+        configure_inductor_for_autobucketing(job_config.experimental.comms_bucket_reorder_strategy)
 
         # Set random seed, and maybe enable deterministic mode
         # (mainly for debugging, expect perf loss).
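Usage note (not part of the patch; a sketch assuming torchtitan's usual TOML job-config and CLI-override conventions): the new Experimental field would map to the [experimental] section of a job config, so the strategy could be selected roughly as

    [experimental]
    comms_bucket_reorder_strategy = "inductor"   # one of "aten" (default), "inductor", "none"

or, equivalently, via a command-line override such as --experimental.comms_bucket_reorder_strategy="inductor". With this change, all of the inductor knobs the old branches set by hand are delegated to configure_inductor_for_autobucketing based on that single string.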