From 3ccd12c87667e93dbacd72c0da24060c85421c96 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 12 Jun 2025 21:08:11 -0700 Subject: [PATCH 01/25] [WIP] Integrate autoparallel into torchtitan TODO - try converting model params into fake tensors - figure out init fn - integrate torchtitan configs for DP/TP to control autop Hack an init_fn for llama3 and observe loss decreasing with autoparallel """ [rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step: 1 loss: 8.1880 memory: 4.88GiB(6.16%) tps: 28 [rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step: 2 loss: 8.1610 memory: 4.90GiB(6.20%) tps: 13,785 [rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step: 3 loss: 8.0871 memory: 4.90GiB(6.20%) tps: 14,006 [rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step: 4 loss: 7.9516 memory: 4.90GiB(6.20%) tps: 13,770 [rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step: 5 loss: 7.8552 memory: 4.90GiB(6.20%) tps: 13,959 [rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step: 6 loss: 7.7732 memory: 4.90GiB(6.20%) tps: 13,859 [rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step: 7 loss: 7.6987 memory: 4.90GiB(6.20%) tps: 13,664 [rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step: 8 loss: 7.6779 memory: 4.90GiB(6.20%) tps: 13,985 [rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step: 9 loss: 7.6043 memory: 4.90GiB(6.20%) tps: 13,962 [rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10 loss: 7.5778 memory: 4.90GiB(6.20%) tps: 13,891 """ Adopt new autoparallel API with meta-init model Allows reverting a lot of the hacks in the original integration that were caused by not creating a model obj in the train.py due to passing a model_fn builder to autop. Fixes to align with latest autoparallel Add inductor config knobs for comms optimizations to torchtitan Make inductor always run compile passes basically, this is an annoying workaround for debugging iteratively. 1- you run the model, it compiles, but something weird happens 2- you enable some logging or tlparse, rerun. but inductor decides not to run your pass anymore, its results are cached. since (2) has confused me horribly on more than one occasion, i just disable caching for now Drop hacky llama3_init_fn and use autop init_weights feature Relying on https://github.com/pytorch-labs/autoparallel/pull/20, this lets us automatically apply a user's init_weights fn to the autoparallel model. Verified this works with `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4` ``` [rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1. 
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step: 1 loss: 8.1848 memory: 1.09GiB(1.14%) tps: 77 tflops: 0.01 mfu: 0.00% [rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step: 2 loss: 8.1619 memory: 1.15GiB(1.21%) tps: 48,138 tflops: 3.46 mfu: 0.35 % [rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step: 3 loss: 8.1140 memory: 1.15GiB(1.21%) tps: 88,440 tflops: 6.36 mfu: 0.64 % [rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step: 4 loss: 8.0099 memory: 1.15GiB(1.21%) tps: 82,626 tflops: 5.94 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step: 5 loss: 7.8928 memory: 1.15GiB(1.21%) tps: 81,594 tflops: 5.87 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step: 6 loss: 7.7758 memory: 1.15GiB(1.21%) tps: 79,607 tflops: 5.72 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step: 7 loss: 7.6221 memory: 1.15GiB(1.21%) tps: 81,448 tflops: 5.86 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step: 8 loss: 7.5578 memory: 1.15GiB(1.21%) tps: 79,732 tflops: 5.73 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step: 9 loss: 7.3851 memory: 1.15GiB(1.21%) tps: 85,655 tflops: 6.16 mfu: 0.62 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10 loss: 7.3361 memory: 1.15GiB(1.21%) tps: 81,855 tflops: 5.89 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete ``` fix lint --- torchtitan/config_manager.py | 22 ++++++ torchtitan/experiments/__init__.py | 1 + .../experiments/auto_parallel/README.md | 7 ++ .../experiments/auto_parallel/__init__.py | 31 ++++++++ .../auto_parallel/parallelize_llama.py | 77 +++++++++++++++++++ torchtitan/train.py | 27 +++++-- 6 files changed, 159 insertions(+), 6 deletions(-) create mode 100644 torchtitan/experiments/auto_parallel/README.md create mode 100644 torchtitan/experiments/auto_parallel/__init__.py create mode 100644 torchtitan/experiments/auto_parallel/parallelize_llama.py diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 5f1a1e8b7f..1a45a7800c 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -664,6 +664,28 @@ class Experimental: needs to ensure that the path can be imported. """ + reorder_for_compute_comm_overlap: bool = False + """ + Whether to enable inductor comm reordering passes + """ + + reorder_for_compute_comm_overlap_passes: list[str] = field( + default_factory=lambda: [ + "sink_waits", + "reorder_communication_preserving_peak_memory", + ] + ) + """ + Sequence of reordering passes (names of functions inside _inductor.comms) to call, + if reorder_for_compute_comm_overlap is enabled. + """ + + reorder_prefetch_limit: int | None = None + """ + How many ops to allow moving any individual collective, if 'reorder_communication_preserving_peak_memory' + pass is enabled. default of None means unlimited + """ + @dataclass class Validation: diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 4c54bdc13e..b7ff983e97 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -4,5 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import torchtitan.experiments.auto_parallel # noqa: F401 import torchtitan.experiments.llama4 # noqa: F401 import torchtitan.experiments.simple_fsdp # noqa: F401 diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/auto_parallel/README.md new file mode 100644 index 0000000000..ef66a59166 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/README.md @@ -0,0 +1,7 @@ +## Auto Parallel + +requires installing git@github.com:pytorch-labs/autoparallel.git + +`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4` + +(or llama3-8b.toml) diff --git a/torchtitan/experiments/auto_parallel/__init__.py b/torchtitan/experiments/auto_parallel/__init__.py new file mode 100644 index 0000000000..8f5a876b4e --- /dev/null +++ b/torchtitan/experiments/auto_parallel/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec +from .parallelize_llama import parallelize_llama + +register_train_spec( + TrainSpec( + name="llama3_auto_parallel", + cls=Transformer, + config=llama3_configs, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) +) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py new file mode 100644 index 0000000000..bb7f1204df --- /dev/null +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import torch + +from autoparallel.api import AutoParallel + +from torch.distributed import DeviceMesh +from torch.distributed.tensor.placement_types import Replicate, Shard + +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims + +from torchtitan.tools.logging import logger + + +def parallelize_llama( + model, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. 
+ dp_degree = world_mesh["dp"].size() + global_batch_size = job_config.training.local_batch_size * dp_degree + return torch.rand( + (global_batch_size, job_config.training.seq_len), device="cuda" + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert ( + len(world_mesh.shape) == 2 + ), "Only support 2D mesh (DP, TP) for now- OK if one has size=1" + assert parallel_dims.dp_shard_enabled is True, "DDP not supported yet" + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + # bail out + # model = model_fn() + # return model + + autop = AutoParallel(model, input_fn, world_mesh) + autop.add_parameter_memory_constraint(low=None, high=None) + + x_sharding = (Shard(0), Replicate()) + + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([x_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + if job_config.training.compile: + torch._inductor.config.reorder_for_peak_memory = False + parallel_mod = torch.compile(parallel_mod, fullgraph=True) + + return parallel_mod diff --git a/torchtitan/train.py b/torchtitan/train.py index 3dc8a61b28..9626d8a5a4 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -12,6 +12,7 @@ import torch from torch.distributed.elastic.multiprocessing.errors import record +from torch.distributed.tensor import DTensor import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module @@ -116,6 +117,21 @@ def __init__(self, job_config: JobConfig): gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug ) + # TODO(whc) + # I do this becuase otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + + # allow configuring inductor comms optimizations from torchtitan commandline + torch._inductor.config.reorder_for_compute_comm_overlap = ( + job_config.experimental.reorder_for_compute_comm_overlap + ) + torch._inductor.config.reorder_for_compute_comm_overlap_passes = ( + job_config.experimental.reorder_for_compute_comm_overlap_passes + ) + torch._inductor.config.reorder_prefetch_limit = ( + job_config.experimental.reorder_prefetch_limit + ) + # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( @@ -141,20 +157,19 @@ def __init__(self, job_config: JobConfig): ) # build model (using meta init) - model_cls = self.train_spec.cls model_args = self.train_spec.config[job_config.model.flavor] + model_cls = self.train_spec.cls # set the model args from training job configs model_args.update_from_config(job_config, tokenizer) - logger.info( f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) + with torch.device("meta"): model = model_cls(model_args) - - # Build the collection of model converters. No-op if `model.converters` empty - model_converters = build_model_converters(job_config, parallel_dims) - model_converters.convert(model) + # Build the collection of model converters. 
No-op if `model.converters` empty + model_converters = build_model_converters(job_config, parallel_dims) + model_converters.convert(model) # metrics logging build_metrics_processor_fn = ( From e6d2cafc4b33868a49af2e9c5bd6563994f4fa64 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 26 Jun 2025 18:34:34 -0700 Subject: [PATCH 02/25] Autoparallel support for DP-only, DP+TP, or TP-only lets existing torchtitan knobs which govern DP/TP mesh creation and mesh size influence the sharding constraints of autoparallel, allowing it to support these different sharding configurations. --- .../auto_parallel/parallelize_llama.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index bb7f1204df..6e0d9b4dcb 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -38,17 +38,13 @@ def input_fn(): if global_batch_size < 0: # This global batch size results in 1 gradient accumulation # step. - dp_degree = world_mesh["dp"].size() + dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard global_batch_size = job_config.training.local_batch_size * dp_degree return torch.rand( (global_batch_size, job_config.training.seq_len), device="cuda" ) # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP - assert ( - len(world_mesh.shape) == 2 - ), "Only support 2D mesh (DP, TP) for now- OK if one has size=1" - assert parallel_dims.dp_shard_enabled is True, "DDP not supported yet" assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" assert parallel_dims.cp_enabled is False, "CP not supported yet" assert parallel_dims.pp_enabled is False, "PP not supported yet" @@ -60,8 +56,18 @@ def input_fn(): autop = AutoParallel(model, input_fn, world_mesh) autop.add_parameter_memory_constraint(low=None, high=None) - x_sharding = (Shard(0), Replicate()) - + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "dp_shard": Shard(0), + "tp": Replicate(), + } + assert all( + name in possible_input_shardings for name in world_mesh.mesh_dim_names + ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + x_sharding = tuple( + possible_input_shardings[name] for name in world_mesh.mesh_dim_names + ) autop.add_input_constraints([x_sharding]) autop.add_output_constraints([x_sharding]) t0 = time.time() From 68476b3fea19dda14f760ea246aac85194c37b39 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 25 Jul 2025 09:25:24 -0700 Subject: [PATCH 03/25] Update CLI inductor configs for bucketing/reordering --- torchtitan/config_manager.py | 8 +++++++- torchtitan/train.py | 8 ++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 1a45a7800c..44567a0fdb 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -664,6 +664,12 @@ class Experimental: needs to ensure that the path can be imported. 
""" + # "none", "all", "only_fsdp" + bucket_all_gathers_fx: str | None = None + + # "none", "all" + bucket_reduce_scatters_fx: str | None = None + reorder_for_compute_comm_overlap: bool = False """ Whether to enable inductor comm reordering passes @@ -671,7 +677,7 @@ class Experimental: reorder_for_compute_comm_overlap_passes: list[str] = field( default_factory=lambda: [ - "sink_waits", + "sink_waits_iterative", "reorder_communication_preserving_peak_memory", ] ) diff --git a/torchtitan/train.py b/torchtitan/train.py index 9626d8a5a4..8ca90037b0 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -120,8 +120,16 @@ def __init__(self, job_config: JobConfig): # TODO(whc) # I do this becuase otherwise sometimes inductor will skip re-running passes like comms reordering torch._inductor.config.force_disable_caches = True + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. + torch._inductor.config.allow_buffer_reuse = False # allow configuring inductor comms optimizations from torchtitan commandline + torch._inductor.config.bucket_all_gathers_fx = ( + job_config.experimental.bucket_all_gathers_fx + ) + torch._inductor.config.bucket_reduce_scatters_fx = ( + job_config.experimental.bucket_reduce_scatters_fx + ) torch._inductor.config.reorder_for_compute_comm_overlap = ( job_config.experimental.reorder_for_compute_comm_overlap ) From 9ee9f75bfd4a009da580922d6cd07cf276db85c9 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 25 Jul 2025 12:07:17 -0700 Subject: [PATCH 04/25] add back llama3_autoparallel_init_fn --- torchtitan/train.py | 75 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 8ca90037b0..35435eecab 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -238,6 +238,78 @@ def __init__(self, job_config: JobConfig): self.loss_fn, self.gradient_accumulation_steps ) + def llama3_autoparallel_init_fn(model): + # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying + # code from the llama3 init_weights functions throughout the model components, and adjusting them to use + # the new FQN structures in autoparallel. 
+ # TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module + def param(name): + return model.get_parameter(f"params.{name}") + + from torchtitan.models.llama3.model.model import precompute_freqs_cis + + model.buffers_.get_buffer("freqs_cis").copy_( + DTensor.from_local( + precompute_freqs_cis( + model_args.dim // model_args.n_heads, + model_args.max_seq_len, + model_args.rope_theta, + ), + device_mesh=model.buffers_.get_buffer("freqs_cis").device_mesh, + ) + ) + + torch.nn.init.normal_(param("tok_embeddings/weight")) + + def init_layer(i): + for norm in ("attention_norm", "ffn_norm"): + torch.nn.init.ones_(param(f"layers/{i}/{norm}/weight")) + + if model_args.depth_init: + weight_init_std = 0.02 / (2 * (i + 1)) ** 0.5 + else: + weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 + + for linear in ("wq", "wk", "wv"): + torch.nn.init.trunc_normal_( + param(f"layers/{i}/attention/{linear}/weight"), + mean=0.0, + std=0.02, + ) + torch.nn.init.trunc_normal_( + param(f"layers/{i}/attention/wo/weight"), + mean=0.0, + std=weight_init_std, + ) + + torch.nn.init.trunc_normal_( + param(f"layers/{i}/feed_forward/w1/weight"), mean=0.0, std=0.02 + ) + for linear in ("w2", "w3"): + torch.nn.init.trunc_normal_( + param(f"layers/{i}/feed_forward/{linear}/weight"), + mean=0.0, + std=weight_init_std, + ) + + for i in range(model_args.n_layers): + init_layer(i) + + if param("norm/weight") is not None: + torch.nn.init.ones_(param("norm/weight")) + + final_out_std = model_args.dim**-0.5 + cutoff_factor = 3 + + if param("output/weight") is not None: + torch.nn.init.trunc_normal_( + param("output/weight"), + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + # apply parallelisms and initialization if parallel_dims.pp_enabled: if not self.train_spec.pipelining_fn: @@ -282,7 +354,8 @@ def __init__(self, job_config: JobConfig): model.to_empty(device=init_device) with torch.no_grad(): - model.init_weights(buffer_device=buffer_device) + # model.init_weights(buffer_device=buffer_device) + llama3_autoparallel_init_fn(model) model.train() self.model_parts = [model] From f6e4099cfa36f00a2b654fef2483ebc55d1a738c Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 27 Jul 2025 19:53:10 -0700 Subject: [PATCH 05/25] Track API change from new AOTAutograd interface Signed-off-by: Edward Z. 
Yang --- .../auto_parallel/parallelize_llama.py | 44 +++++++++---------- torchtitan/train.py | 6 +-- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index 6e0d9b4dcb..4b0e7a3e03 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -53,28 +53,28 @@ def input_fn(): # model = model_fn() # return model - autop = AutoParallel(model, input_fn, world_mesh) - autop.add_parameter_memory_constraint(low=None, high=None) - - possible_input_shardings = { - # maps relative to mesh dim names used in torchtitan - "dp_replicate": Shard(0), - "dp_shard": Shard(0), - "tp": Replicate(), - } - assert all( - name in possible_input_shardings for name in world_mesh.mesh_dim_names - ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" - x_sharding = tuple( - possible_input_shardings[name] for name in world_mesh.mesh_dim_names - ) - autop.add_input_constraints([x_sharding]) - autop.add_output_constraints([x_sharding]) - t0 = time.time() - sharding_placement = autop.optimize_placement() - t1 = time.time() - logger.info(f"AutoParallel took {t1 - t0} seconds") - parallel_mod = autop.apply_placement(sharding_placement) + with AutoParallel(model, input_fn, world_mesh) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "dp_shard": Shard(0), + "tp": Replicate(), + } + assert all( + name in possible_input_shardings for name in world_mesh.mesh_dim_names + ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + x_sharding = tuple( + possible_input_shardings[name] for name in world_mesh.mesh_dim_names + ) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([x_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) if job_config.training.compile: torch._inductor.config.reorder_for_peak_memory = False diff --git a/torchtitan/train.py b/torchtitan/train.py index 35435eecab..cd85e27045 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -244,18 +244,18 @@ def llama3_autoparallel_init_fn(model): # the new FQN structures in autoparallel. 
# TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module def param(name): - return model.get_parameter(f"params.{name}") + return model.get_parameter(f"{name.replace('/', '.')}") from torchtitan.models.llama3.model.model import precompute_freqs_cis - model.buffers_.get_buffer("freqs_cis").copy_( + model.get_buffer("freqs_cis").copy_( DTensor.from_local( precompute_freqs_cis( model_args.dim // model_args.n_heads, model_args.max_seq_len, model_args.rope_theta, ), - device_mesh=model.buffers_.get_buffer("freqs_cis").device_mesh, + device_mesh=model.get_buffer("freqs_cis").device_mesh, ) ) From 4d7ee8a0b864525e8194876d4d91ea2f6e9be743 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 28 Jul 2025 14:30:49 -0700 Subject: [PATCH 06/25] Support forcing the model into bf16 for perf debugging --- torchtitan/config_manager.py | 1 + torchtitan/experiments/auto_parallel/parallelize_llama.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 44567a0fdb..ee7de65105 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -692,6 +692,7 @@ class Experimental: pass is enabled. default of None means unlimited """ + autop_force_bf16: bool = False @dataclass class Validation: diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index 4b0e7a3e03..a8f7b217cd 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -52,6 +52,9 @@ def input_fn(): # bail out # model = model_fn() # return model + if job_config.experimental.autop_force_bf16: + logger.info("Forcing bf16 on model") + model = model.bfloat16() with AutoParallel(model, input_fn, world_mesh) as autop: autop.add_parameter_memory_constraint(low=None, high=None) From b801d0b5b307617e25e9563694397673cb3041ee Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 29 Jul 2025 10:56:50 -0700 Subject: [PATCH 07/25] Integrate MixedPrecision with AutoParallel and fix example_inputs --- .../auto_parallel/parallelize_llama.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index a8f7b217cd..bb03e81c95 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -11,9 +11,10 @@ from autoparallel.api import AutoParallel from torch.distributed import DeviceMesh +from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor.placement_types import Replicate, Shard -from torchtitan.config_manager import JobConfig +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.tools.logging import logger @@ -40,9 +41,13 @@ def input_fn(): # step. 
dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard global_batch_size = job_config.training.local_batch_size * dp_degree - return torch.rand( - (global_batch_size, job_config.training.seq_len), device="cuda" - ) + return torch.randint( + 0, + # job_config.training.vocab_size, + model.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" @@ -56,7 +61,10 @@ def input_fn(): logger.info("Forcing bf16 on model") model = model.bfloat16() - with AutoParallel(model, input_fn, world_mesh) as autop: + param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] + reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + with AutoParallel(model, input_fn, world_mesh, mp_policy=mp_policy) as autop: autop.add_parameter_memory_constraint(low=None, high=None) possible_input_shardings = { From b099cf9c06b11084b01645a7878ff8ee91d82ef3 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 29 Jul 2025 13:34:16 -0700 Subject: [PATCH 08/25] Use in-place compile API Signed-off-by: Edward Z. Yang --- torchtitan/experiments/auto_parallel/parallelize_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index bb03e81c95..39ca4ab1da 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -89,6 +89,6 @@ def input_fn(): if job_config.training.compile: torch._inductor.config.reorder_for_peak_memory = False - parallel_mod = torch.compile(parallel_mod, fullgraph=True) + parallel_mod.compile(fullgraph=True) return parallel_mod From b3587d9fec06fee75fd5398673831a95a08b083c Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 29 Jul 2025 13:14:05 -0700 Subject: [PATCH 09/25] Fix bucketing pass configs - fix passing of "none" (not None) to control bucketing passes --- torchtitan/config_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index ee7de65105..131d7d2bd8 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -665,10 +665,10 @@ class Experimental: """ # "none", "all", "only_fsdp" - bucket_all_gathers_fx: str | None = None + bucket_all_gathers_fx: str = "none" # "none", "all" - bucket_reduce_scatters_fx: str | None = None + bucket_reduce_scatters_fx: str = "none" reorder_for_compute_comm_overlap: bool = False """ From 42c2c07613d0239c5b630bcfda534db2ffd66ab1 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 30 Jul 2025 13:06:14 -0700 Subject: [PATCH 10/25] Support both eager and autoparallel init based on model.name --- torchtitan/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index cd85e27045..793957a761 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -354,8 +354,10 @@ def init_layer(i): model.to_empty(device=init_device) with torch.no_grad(): - # model.init_weights(buffer_device=buffer_device) - llama3_autoparallel_init_fn(model) + if job_config.model.name == "llama3_auto_parallel": + llama3_autoparallel_init_fn(model) + else: + 
model.init_weights(buffer_device=buffer_device) model.train() self.model_parts = [model] From d93845e57e88c2fa9dc7eeda6e754129eb01660f Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 6 Aug 2025 11:57:11 -0700 Subject: [PATCH 11/25] Remove llama3 init weights hack since https://github.com/meta-pytorch/autoparallel/pull/52 landed --- torchtitan/train.py | 77 +-------------------------------------------- 1 file changed, 1 insertion(+), 76 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 793957a761..8ca90037b0 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -238,78 +238,6 @@ def __init__(self, job_config: JobConfig): self.loss_fn, self.gradient_accumulation_steps ) - def llama3_autoparallel_init_fn(model): - # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying - # code from the llama3 init_weights functions throughout the model components, and adjusting them to use - # the new FQN structures in autoparallel. - # TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module - def param(name): - return model.get_parameter(f"{name.replace('/', '.')}") - - from torchtitan.models.llama3.model.model import precompute_freqs_cis - - model.get_buffer("freqs_cis").copy_( - DTensor.from_local( - precompute_freqs_cis( - model_args.dim // model_args.n_heads, - model_args.max_seq_len, - model_args.rope_theta, - ), - device_mesh=model.get_buffer("freqs_cis").device_mesh, - ) - ) - - torch.nn.init.normal_(param("tok_embeddings/weight")) - - def init_layer(i): - for norm in ("attention_norm", "ffn_norm"): - torch.nn.init.ones_(param(f"layers/{i}/{norm}/weight")) - - if model_args.depth_init: - weight_init_std = 0.02 / (2 * (i + 1)) ** 0.5 - else: - weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 - - for linear in ("wq", "wk", "wv"): - torch.nn.init.trunc_normal_( - param(f"layers/{i}/attention/{linear}/weight"), - mean=0.0, - std=0.02, - ) - torch.nn.init.trunc_normal_( - param(f"layers/{i}/attention/wo/weight"), - mean=0.0, - std=weight_init_std, - ) - - torch.nn.init.trunc_normal_( - param(f"layers/{i}/feed_forward/w1/weight"), mean=0.0, std=0.02 - ) - for linear in ("w2", "w3"): - torch.nn.init.trunc_normal_( - param(f"layers/{i}/feed_forward/{linear}/weight"), - mean=0.0, - std=weight_init_std, - ) - - for i in range(model_args.n_layers): - init_layer(i) - - if param("norm/weight") is not None: - torch.nn.init.ones_(param("norm/weight")) - - final_out_std = model_args.dim**-0.5 - cutoff_factor = 3 - - if param("output/weight") is not None: - torch.nn.init.trunc_normal_( - param("output/weight"), - mean=0.0, - std=final_out_std, - a=-cutoff_factor * final_out_std, - b=cutoff_factor * final_out_std, - ) - # apply parallelisms and initialization if parallel_dims.pp_enabled: if not self.train_spec.pipelining_fn: @@ -354,10 +282,7 @@ def init_layer(i): model.to_empty(device=init_device) with torch.no_grad(): - if job_config.model.name == "llama3_auto_parallel": - llama3_autoparallel_init_fn(model) - else: - model.init_weights(buffer_device=buffer_device) + model.init_weights(buffer_device=buffer_device) model.train() self.model_parts = [model] From 60f5f118f4c17747bb117d4a8018eec180e0bd89 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 7 Aug 2025 15:25:02 -0700 Subject: [PATCH 12/25] Print profiling manifold url prints an (internal, vpn) only link for each profile trace file that's saved to manifold. 
Just search for 'trace' in your job logs on mast, and click one of the rank links. e.g. [trainer37|5]:[titan] 2025-08-07 14:21:01,227 - root - INFO - Finished dumping profiler traces in 5.22 seconds: https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/torchtrain_datasets/tree/outputs/torchtitan-64-whc-jv2j4mp/profile_trace/iteration_20/rank37_trace.json --- torchtitan/tools/profiling.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 050b992cc8..7c822c3b66 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -19,7 +19,9 @@ # how much memory allocation/free ops to record in memory snapshots MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 - +PERFETTO_UI_ROOT_URL = ( + "https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html" +) @contextlib.contextmanager def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0): @@ -42,10 +44,22 @@ def trace_handler(prof): logger.info(f"Dumping profiler traces at step {prof.step_num}") begin = time.monotonic() - prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json") - logger.info( - f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" - ) + filename = f"{curr_trace_dir}/rank{rank}_trace.json" + + prof.export_chrome_trace(filename) + log_str = f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" + # not directly landable on upstream titan, + # but conveniently prints the internal url for perfetto on manifold for mast jobs + manifold_mount_prefix = "/mnt/mffuse/" + if filename.find(manifold_mount_prefix) == 0: + manifold_path = os.path.join("torchtrain_datasets/tree", filename.split(manifold_mount_prefix)[1]) + perfetto_url = ( + PERFETTO_UI_ROOT_URL + + "#!/?url=https://interncache-all.fbcdn.net/manifold/" + + manifold_path + ) + log_str += f": {perfetto_url}" + logger.info(log_str) logger.info(f"Profiling active. 
Traces will be saved at {trace_dir}") From 6c782eba53b51ee2f95af0be98f49c443ff9b52f Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 7 Aug 2025 20:07:39 -0700 Subject: [PATCH 13/25] Support new compile API from autoparallel PR #77 --- torchtitan/experiments/auto_parallel/parallelize_llama.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index 39ca4ab1da..2d3a3e2e2c 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -64,7 +64,7 @@ def input_fn(): param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) - with AutoParallel(model, input_fn, world_mesh, mp_policy=mp_policy) as autop: + with AutoParallel(model, input_fn, world_mesh, mp_policy=mp_policy, compile=job_config.training.compile) as autop: autop.add_parameter_memory_constraint(low=None, high=None) possible_input_shardings = { @@ -87,8 +87,4 @@ def input_fn(): logger.info(f"AutoParallel took {t1 - t0} seconds") parallel_mod = autop.apply_placement(sharding_placement) - if job_config.training.compile: - torch._inductor.config.reorder_for_peak_memory = False - parallel_mod.compile(fullgraph=True) - return parallel_mod From 4712163eb9d0ac6709711de8aa708abd7f8d38d3 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 8 Aug 2025 15:29:37 +0200 Subject: [PATCH 14/25] Fix bucket sizes for AutoParallel 1D (#1545) This PR makes bucket sizes for all-gather and reduce-scatter to be of the same size for 1d FSDP. --- .../auto_parallel/parallelize_llama.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index 2d3a3e2e2c..d003cb2e5f 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -41,19 +41,28 @@ def input_fn(): # step. 
dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard global_batch_size = job_config.training.local_batch_size * dp_degree - return torch.randint( - 0, - # job_config.training.vocab_size, - model.vocab_size, - (global_batch_size, job_config.training.seq_len), - device=torch.device("cuda"), - ), + return ( + torch.randint( + 0, + # job_config.training.vocab_size, + model.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" assert parallel_dims.cp_enabled is False, "CP not supported yet" assert parallel_dims.pp_enabled is False, "PP not supported yet" + torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( + lambda bucket_idx: 500 / parallel_dims.tp + ) + torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = ( + lambda bucket_idx: 1000 / parallel_dims.tp + ) + # bail out # model = model_fn() # return model @@ -64,7 +73,13 @@ def input_fn(): param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) - with AutoParallel(model, input_fn, world_mesh, mp_policy=mp_policy, compile=job_config.training.compile) as autop: + with AutoParallel( + model, + input_fn, + world_mesh, + mp_policy=mp_policy, + compile=job_config.training.compile, + ) as autop: autop.add_parameter_memory_constraint(low=None, high=None) possible_input_shardings = { From 3f04d223f7f3bd98510c8e288ff72183ff6caa5b Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sun, 10 Aug 2025 19:12:30 +0200 Subject: [PATCH 15/25] Add support for loss parallel (#1546) IMO we should just add the loss in the model and let autoparallel parallelize it for us. 
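As a rough illustration of that direction (a sketch only, not part of this patch; `ModelWithLoss` and its argument names are made up here), the loss could live inside the model's forward so AutoParallel sees it and can pick a parallelization for it too:

```
# Hypothetical sketch: wrap the model so cross-entropy is computed inside
# forward(), letting AutoParallel shard the loss computation as well.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ModelWithLoss(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, tokens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits = self.model(tokens)  # (bs, seq_len, vocab_size)
        # flatten to (bs * seq_len, vocab_size) vs (bs * seq_len,) for cross entropy
        return F.cross_entropy(logits.flatten(0, 1).float(), labels.flatten(0, 1))
```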
But for now, let's follow how the other models are implemented --- .../auto_parallel/parallelize_llama.py | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index d003cb2e5f..49a8bc49ff 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -88,18 +88,48 @@ def input_fn(): "dp_shard": Shard(0), "tp": Replicate(), } + # only used if loss parallel is enabled + possible_output_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_shard": Shard(0), + "tp": Shard(2), + } assert all( name in possible_input_shardings for name in world_mesh.mesh_dim_names ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" x_sharding = tuple( possible_input_shardings[name] for name in world_mesh.mesh_dim_names ) + out_sharding = x_sharding + if parallel_dims.loss_parallel_enabled: + out_sharding = tuple( + possible_output_shardings[name] + for name in world_mesh.mesh_dim_names + if name != "dp_replicate" + ) autop.add_input_constraints([x_sharding]) - autop.add_output_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) t0 = time.time() sharding_placement = autop.optimize_placement() t1 = time.time() logger.info(f"AutoParallel took {t1 - t0} seconds") parallel_mod = autop.apply_placement(sharding_placement) + if parallel_dims.loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. + # We should just use AutoParallel to do this for us, but + # it would require putting the loss inside the model as well + def _return_as_dtensor_for_loss_parallel(module, args, output): + return torch.distributed.tensor.DTensor.from_local( + output, world_mesh["tp"], (Shard(2),) + ) + + # not keeping a reference to the hook, don't plan on + # removing it at any point + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + return parallel_mod From 8e50870985239adb109733e599f389350706b03b Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 18 Aug 2025 15:51:30 -0700 Subject: [PATCH 16/25] Add config for running simple-fsdp bucketing/reordering passes just add `--experimental.enable_simplefsdp_passes` and do not try to combine it with other `bucket_*` or `reorder_*` options. 
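For example, extending the earlier debug-model run (shown for illustration only; it assumes the same config file and model name used in previous patches, and that compile is enabled via the usual `--training.compile` knob, since these inductor passes only apply to compiled runs): `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --training.compile --experimental.enable_simplefsdp_passes`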
--- torchtitan/config_manager.py | 2 ++ torchtitan/train.py | 58 ++++++++++++++++++++++++++---------- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 131d7d2bd8..efc9d7a951 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -694,6 +694,8 @@ class Experimental: autop_force_bf16: bool = False + enable_simplefsdp_passes: bool = False + @dataclass class Validation: enabled: bool = False diff --git a/torchtitan/train.py b/torchtitan/train.py index 8ca90037b0..1ebdd984ed 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -124,21 +124,49 @@ def __init__(self, job_config: JobConfig): torch._inductor.config.allow_buffer_reuse = False # allow configuring inductor comms optimizations from torchtitan commandline - torch._inductor.config.bucket_all_gathers_fx = ( - job_config.experimental.bucket_all_gathers_fx - ) - torch._inductor.config.bucket_reduce_scatters_fx = ( - job_config.experimental.bucket_reduce_scatters_fx - ) - torch._inductor.config.reorder_for_compute_comm_overlap = ( - job_config.experimental.reorder_for_compute_comm_overlap - ) - torch._inductor.config.reorder_for_compute_comm_overlap_passes = ( - job_config.experimental.reorder_for_compute_comm_overlap_passes - ) - torch._inductor.config.reorder_prefetch_limit = ( - job_config.experimental.reorder_prefetch_limit - ) + if job_config.experimental.enable_simplefsdp_passes: + try: + from torch._inductor.simple_fsdp.bucket import bucket_fsdp_all_gather_concat_on_scheduler_ir + except ImportError: + print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c") + raise + + # Configs from Ruisi + + # set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives + torch._inductor.config.simplefsdp.relax_ratio = 0 + torch._inductor.config.allow_buffer_reuse = False + torch._inductor.config.simplefsdp.estimate_ir = False + torch._inductor.config.simplefsdp.estimate_verbose = False + torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl" + # set to True after the first communication estimation results are saved. This would reduce decision making time. + torch._inductor.config.simplefsdp.load_cache = False + torch._inductor.config.simplefsdp.enable_bucket_ir = True + torch._inductor.config.simplefsdp.enable_reorder_ir = True + torch._inductor.config.simplefsdp.simplefsdp_only = True # False for 2d True for 1d + torch._inductor.config.simplefsdp.peak_memory_offset = 0 + torch._inductor.config.simplefsdp.bucketing_type = "auto" + + # Don't use both sets of passes at the same time! 
+ torch._inductor.config.bucket_all_gathers_fx = "none" + torch._inductor.config.bucket_reduce_scatters_fx = "none" + torch._inductor.config.reorder_for_compute_comm_overlap = False + else: + torch._inductor.config.bucket_all_gathers_fx = ( + job_config.experimental.bucket_all_gathers_fx + ) + torch._inductor.config.bucket_reduce_scatters_fx = ( + job_config.experimental.bucket_reduce_scatters_fx + ) + torch._inductor.config.reorder_for_compute_comm_overlap = ( + job_config.experimental.reorder_for_compute_comm_overlap + ) + torch._inductor.config.reorder_for_compute_comm_overlap_passes = ( + job_config.experimental.reorder_for_compute_comm_overlap_passes + ) + torch._inductor.config.reorder_prefetch_limit = ( + job_config.experimental.reorder_prefetch_limit + ) # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). From 91c56397e8f7c98ec2c685bdf551637f8780b56c Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 19 Aug 2025 13:46:38 -0700 Subject: [PATCH 17/25] Hook up deepseekv3_auto_parallel This command should now run `CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel` However it doesn't actually do anything with autoparallel yet. Next step is to attach local_map to the model so that autoparallel can run. --- .../experiments/auto_parallel/__init__.py | 18 +++ .../auto_parallel/parallelize_deepseekv3.py | 135 ++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py diff --git a/torchtitan/experiments/auto_parallel/__init__.py b/torchtitan/experiments/auto_parallel/__init__.py index 8f5a876b4e..d707d43088 100644 --- a/torchtitan/experiments/auto_parallel/__init__.py +++ b/torchtitan/experiments/auto_parallel/__init__.py @@ -13,7 +13,11 @@ from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer from torchtitan.protocols.train_spec import register_train_spec, TrainSpec +from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers +from torchtitan.models.deepseek_v3 import deepseekv3_configs, DeepSeekV3Model from .parallelize_llama import parallelize_llama +from .parallelize_deepseekv3 import parallelize_deepseekv3 + register_train_spec( TrainSpec( @@ -29,3 +33,17 @@ build_loss_fn=build_cross_entropy_loss, ) ) +register_train_spec( + TrainSpec( + name="deepseekv3_auto_parallel", + cls=DeepSeekV3Model, + config=deepseekv3_configs, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=None, + build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) +) diff --git a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py new file mode 100644 index 0000000000..946ec8a199 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import time + +import torch + +from autoparallel.api import AutoParallel + +from torch.distributed import DeviceMesh +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard + +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims + +from torchtitan.tools.logging import logger + + +def parallelize_deepseekv3( + model, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard + global_batch_size = job_config.training.local_batch_size * dp_degree + return ( + torch.randint( + 0, + # job_config.training.vocab_size, + model.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + # torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( + # lambda bucket_idx: 500 / parallel_dims.tp + # ) + # torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = ( + # lambda bucket_idx: 1000 / parallel_dims.tp + # ) + + # bail out + return model + + # if job_config.experimental.autop_force_bf16: + # logger.info("Forcing bf16 on model") + # model = model.bfloat16() + + # param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] + # reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] + # mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + # with AutoParallel( + # model, + # input_fn, + # world_mesh, + # mp_policy=mp_policy, + # compile=job_config.training.compile, + # ) as autop: + # autop.add_parameter_memory_constraint(low=None, high=None) + + # possible_input_shardings = { + # # maps relative to mesh dim names used in torchtitan + # "dp_replicate": Shard(0), + # "dp_shard": Shard(0), + # "tp": Replicate(), + # } + # # only used if loss parallel is enabled + # possible_output_shardings = { + # # maps relative to mesh dim names used in torchtitan + # "dp_shard": Shard(0), + # "tp": Shard(2), + # } + # assert all( + # name in possible_input_shardings for name in world_mesh.mesh_dim_names + # ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + # x_sharding = tuple( + # possible_input_shardings[name] for name in world_mesh.mesh_dim_names + # ) + # out_sharding = x_sharding + # if parallel_dims.loss_parallel_enabled: + # out_sharding = tuple( + # possible_output_shardings[name] + # for name in world_mesh.mesh_dim_names + # if name != "dp_replicate" + # ) + # autop.add_input_constraints([x_sharding]) + # autop.add_output_constraints([out_sharding]) + # t0 = time.time() + # sharding_placement = autop.optimize_placement() + # t1 = time.time() + # 
logger.info(f"AutoParallel took {t1 - t0} seconds") + # parallel_mod = autop.apply_placement(sharding_placement) + + if parallel_dims.loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. + # We should just use AutoParallel to do this for us, but + # it would require putting the loss inside the model as well + def _return_as_dtensor_for_loss_parallel(module, args, output): + return torch.distributed.tensor.DTensor.from_local( + output, world_mesh["tp"], (Shard(2),) + ) + + # not keeping a reference to the hook, don't plan on + # removing it at any point + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + + return parallel_mod From 1233902a54e88851f4381349d6df1ecb67134ba7 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 19 Aug 2025 15:49:44 -0700 Subject: [PATCH 18/25] [dsv3] patch graph break fix, works up until sharding rules --- .../auto_parallel/parallelize_deepseekv3.py | 24 +-- torchtitan/models/deepseek_v3/model/moe.py | 140 +++++++++--------- 2 files changed, 81 insertions(+), 83 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py index 946ec8a199..7ef9110acc 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py +++ b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py @@ -44,8 +44,7 @@ def input_fn(): return ( torch.randint( 0, - # job_config.training.vocab_size, - model.vocab_size, + model.model_args.vocab_size, (global_batch_size, job_config.training.seq_len), device=torch.device("cuda"), ), @@ -63,9 +62,6 @@ def input_fn(): # lambda bucket_idx: 1000 / parallel_dims.tp # ) - # bail out - return model - # if job_config.experimental.autop_force_bf16: # logger.info("Forcing bf16 on model") # model = model.bfloat16() @@ -73,13 +69,17 @@ def input_fn(): # param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] # reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] # mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) - # with AutoParallel( - # model, - # input_fn, - # world_mesh, - # mp_policy=mp_policy, - # compile=job_config.training.compile, - # ) as autop: + mp_policy = None + with AutoParallel( + model, + input_fn, + world_mesh, + mp_policy=mp_policy, + compile=job_config.training.compile, + ) as autop: + # currently errors due to missing sharding prop rules + torch.distributed.breakpoint() + # autop.add_parameter_memory_constraint(low=None, high=None) # possible_input_shardings = { diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py index 2554d61310..86408f82c5 100644 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ b/torchtitan/models/deepseek_v3/model/moe.py @@ -48,6 +48,73 @@ def init_weights(self, init_std: float = 0.02): nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) +# TODO: keeping this for-loop implementation for comparison +# and readability, may remove later +@expert_parallel +def _run_experts_for_loop( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, +) -> torch.Tensor: + if num_tokens_per_expert is not None: + # NOTE: this would incur a synchronization between device and host + 
num_tokens_per_expert = num_tokens_per_expert.tolist() + + # side-effect code due to the usage of generate_permute_indices + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + h = F.silu(torch.matmul(x_expert, w1[expert_idx])) + h = h * torch.matmul(x_expert, w3[expert_idx]) + h = torch.matmul(h, w2[expert_idx]) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # side-effect code due to the usage of generate_permute_indices + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = F.silu(torch.bmm(x, w1)) + h = h * torch.bmm(x, w3) + # out shape (num_experts, tokens_per_expert, dim) + out = torch.bmm(h, w2) + + return out + +@expert_parallel +def _run_experts_grouped_mm( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, +) -> torch.Tensor: + if num_tokens_per_expert is not None: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + # grouped mm between a 2D tensor and a 3D tensor + assert x.dim() == 2 + else: + offsets = None + # fall back to regular bmm between 3D tensors + assert x.dim() == 3 + + h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets)) + h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets) + out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x) + + return out + class GroupedExperts(nn.Module): def __init__( self, @@ -69,83 +136,14 @@ def forward( num_tokens_per_expert: torch.Tensor | None = None, ) -> torch.Tensor: if self.use_grouped_mm: - return GroupedExperts._run_experts_grouped_mm( + return _run_experts_grouped_mm( self.w1, self.w2, self.w3, x, num_tokens_per_expert ) else: - return GroupedExperts._run_experts_for_loop( + return _run_experts_for_loop( self.w1, self.w2, self.w3, x, num_tokens_per_expert ) - # TODO: keeping this for-loop implementation for comparison - # and readability, may remove later - @expert_parallel - @staticmethod - def _run_experts_for_loop( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, - ) -> torch.Tensor: - if num_tokens_per_expert is not None: - # NOTE: this would incur a synchronization between device and host - num_tokens_per_expert = num_tokens_per_expert.tolist() - - # side-effect code due to the usage of generate_permute_indices - num_padding = x.shape[0] - sum(num_tokens_per_expert) - - # a tuple of tensors indexed by experts - # each with shape (tokens_per_expert(varying), dim) - x = torch.split( - x[: sum(num_tokens_per_expert)], - split_size_or_sections=num_tokens_per_expert, - dim=0, - ) - out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): - h = F.silu(torch.matmul(x_expert, w1[expert_idx])) - h = h * torch.matmul(x_expert, w3[expert_idx]) - h = torch.matmul(h, w2[expert_idx]) - # h shape (tokens_per_expert(varying), dim) - out_experts_splits.append(h) - out = torch.cat(out_experts_splits, dim=0) - - # side-effect code due to the usage of generate_permute_indices - out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) - else: - # x 
shape (num_experts, tokens_per_expert, dim) - h = F.silu(torch.bmm(x, w1)) - h = h * torch.bmm(x, w3) - # out shape (num_experts, tokens_per_expert, dim) - out = torch.bmm(h, w2) - - return out - - @expert_parallel - @staticmethod - def _run_experts_grouped_mm( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, - ) -> torch.Tensor: - if num_tokens_per_expert is not None: - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - # grouped mm between a 2D tensor and a 3D tensor - assert x.dim() == 2 - else: - offsets = None - # fall back to regular bmm between 3D tensors - assert x.dim() == 3 - - h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets)) - h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets) - out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x) - - return out - def init_weights(self, init_std: float): nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02) nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std) From 4f8677b79b14d4157b01dde9c26a2c1ac0f8a695 Mon Sep 17 00:00:00 2001 From: ruisizhang123 Date: Thu, 21 Aug 2025 10:20:47 -0700 Subject: [PATCH 19/25] update simplefsdp pass config --- torchtitan/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 1ebdd984ed..e4da1b1f6b 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -131,7 +131,7 @@ def __init__(self, job_config: JobConfig): print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c") raise - # Configs from Ruisi + # Configs from Ruisi # set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives torch._inductor.config.simplefsdp.relax_ratio = 0 @@ -140,10 +140,10 @@ def __init__(self, job_config: JobConfig): torch._inductor.config.simplefsdp.estimate_verbose = False torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl" # set to True after the first communication estimation results are saved. This would reduce decision making time. 
- torch._inductor.config.simplefsdp.load_cache = False + torch._inductor.config.simplefsdp.load_cache = False torch._inductor.config.simplefsdp.enable_bucket_ir = True torch._inductor.config.simplefsdp.enable_reorder_ir = True - torch._inductor.config.simplefsdp.simplefsdp_only = True # False for 2d True for 1d + torch._inductor.config.simplefsdp.simplefsdp_only = False # False for 2d True for 1d torch._inductor.config.simplefsdp.peak_memory_offset = 0 torch._inductor.config.simplefsdp.bucketing_type = "auto" From 714cc5b395cac355a146fa66ab0b55e434c50f75 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 21 Aug 2025 23:29:37 -0700 Subject: [PATCH 20/25] [dsv3] disable MoE while we fix local_map, works up until optimizer --- .../auto_parallel/parallelize_deepseekv3.py | 71 +++++++++---------- torchtitan/models/deepseek_v3/model/model.py | 4 +- 2 files changed, 37 insertions(+), 38 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py index 7ef9110acc..efda3327c2 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py +++ b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py @@ -77,43 +77,40 @@ def input_fn(): mp_policy=mp_policy, compile=job_config.training.compile, ) as autop: - # currently errors due to missing sharding prop rules - torch.distributed.breakpoint() - - # autop.add_parameter_memory_constraint(low=None, high=None) - - # possible_input_shardings = { - # # maps relative to mesh dim names used in torchtitan - # "dp_replicate": Shard(0), - # "dp_shard": Shard(0), - # "tp": Replicate(), - # } - # # only used if loss parallel is enabled - # possible_output_shardings = { - # # maps relative to mesh dim names used in torchtitan - # "dp_shard": Shard(0), - # "tp": Shard(2), - # } - # assert all( - # name in possible_input_shardings for name in world_mesh.mesh_dim_names - # ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" - # x_sharding = tuple( - # possible_input_shardings[name] for name in world_mesh.mesh_dim_names - # ) - # out_sharding = x_sharding - # if parallel_dims.loss_parallel_enabled: - # out_sharding = tuple( - # possible_output_shardings[name] - # for name in world_mesh.mesh_dim_names - # if name != "dp_replicate" - # ) - # autop.add_input_constraints([x_sharding]) - # autop.add_output_constraints([out_sharding]) - # t0 = time.time() - # sharding_placement = autop.optimize_placement() - # t1 = time.time() - # logger.info(f"AutoParallel took {t1 - t0} seconds") - # parallel_mod = autop.apply_placement(sharding_placement) + autop.add_parameter_memory_constraint(low=None, high=None) + + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "dp_shard": Shard(0), + "tp": Replicate(), + } + # only used if loss parallel is enabled + possible_output_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_shard": Shard(0), + "tp": Shard(2), + } + assert all( + name in possible_input_shardings for name in world_mesh.mesh_dim_names + ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + x_sharding = tuple( + possible_input_shardings[name] for name in world_mesh.mesh_dim_names + ) + out_sharding = x_sharding + if parallel_dims.loss_parallel_enabled: + out_sharding = tuple( + possible_output_shardings[name] + for name in world_mesh.mesh_dim_names + if name != 
"dp_replicate" + ) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) if parallel_dims.loss_parallel_enabled: diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 61034e4c7d..946de75b58 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -270,7 +270,9 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.attention = Attention(model_args) self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.moe_enabled = layer_id >= model_args.n_dense_layers + # self.moe_enabled = layer_id >= model_args.n_dense_layers + # TODO: enable me when local_map works + self.moe_enabled = False if self.moe_enabled: self.moe = MoE(model_args) From bfa9f7f9c414bcea16b62ea08d1fcdb7463e8937 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 5 Sep 2025 15:10:35 -0400 Subject: [PATCH 21/25] tweak ds3 model.py to reflect main branch for DS3 baseline can run (#1684) --- torchtitan/models/deepseek_v3/model/model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 8f8a512f63..9074919c99 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -270,9 +270,7 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.attention = Attention(model_args) self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - # self.moe_enabled = layer_id >= model_args.n_dense_layers - # TODO: enable me when local_map works - self.moe_enabled = False + self.moe_enabled = layer_id >= model_args.n_dense_layers if self.moe_enabled: self.moe = MoE( From 75fb2eb4d0eb9438a7e6a7a0952032995f4f6e2a Mon Sep 17 00:00:00 2001 From: Ruisi Zhang Date: Sat, 6 Sep 2025 00:28:39 -0700 Subject: [PATCH 22/25] add simplefsdp's autobucketing pass entry (#1658) as titled, this pr adds entry to simplefsdp's autobucketing pass in autoparallel. original code is in: https://github.com/pytorch/pytorch/pull/160282 The main code for autobucketing pass will be added to autoparallel repo. 
--- .../experiments/auto_parallel/README.md | 4 ++ torchtitan/train.py | 38 +++++++++---------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/auto_parallel/README.md index ef66a59166..7e112329b9 100644 --- a/torchtitan/experiments/auto_parallel/README.md +++ b/torchtitan/experiments/auto_parallel/README.md @@ -4,4 +4,8 @@ requires installing git@github.com:pytorch-labs/autoparallel.git `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4` +Use simplefsdp's autobucketing pass: + +`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --experimental.enable_simplefsdp_passes --compile.enable` + (or llama3-8b.toml) diff --git a/torchtitan/train.py b/torchtitan/train.py index 2829aa3c55..b69b74faac 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -8,6 +8,7 @@ import os import time from datetime import timedelta +from functools import partial from typing import Any, Generator, Iterable, Optional import torch @@ -130,32 +131,29 @@ def __init__(self, job_config: JobConfig): # allow configuring inductor comms optimizations from torchtitan commandline if job_config.experimental.enable_simplefsdp_passes: - try: - from torch._inductor.simple_fsdp.bucket import bucket_fsdp_all_gather_concat_on_scheduler_ir - except ImportError: - print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c") - raise - - # Configs from Ruisi + # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282) + from autoparallel.auto_bucketing import ( + simple_fsdp_autobucketing_reordering_pass, + simplefsdp_autobucketing_config, + ) - # set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives - torch._inductor.config.simplefsdp.relax_ratio = 0 torch._inductor.config.allow_buffer_reuse = False - torch._inductor.config.simplefsdp.estimate_ir = False - torch._inductor.config.simplefsdp.estimate_verbose = False - torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl" - # set to True after the first communication estimation results are saved. This would reduce decision making time. - torch._inductor.config.simplefsdp.load_cache = False - torch._inductor.config.simplefsdp.enable_bucket_ir = True - torch._inductor.config.simplefsdp.enable_reorder_ir = True - torch._inductor.config.simplefsdp.simplefsdp_only = False # False for 2d True for 1d - torch._inductor.config.simplefsdp.peak_memory_offset = 0 - torch._inductor.config.simplefsdp.bucketing_type = "auto" + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.reorder_for_compute_comm_overlap = True + simplefsdp_autobucketing_config.save_estimation_path = ( + "/tmp/torchtitan_simplefsdp_comm_estimation.pkl" + ) + simple_fsdp_autobucketing_reordering_pass = partial( + simple_fsdp_autobucketing_reordering_pass, + configs=simplefsdp_autobucketing_config, + ) + torch._inductor.config.reorder_for_compute_comm_overlap_passes = [ + simple_fsdp_autobucketing_reordering_pass + ] # Don't use both sets of passes at the same time! 
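+        # (i.e. the simplefsdp autobucketing/reordering pass registered above vs. inductor's fx-level bucketing passes configured below)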
torch._inductor.config.bucket_all_gathers_fx = "none" torch._inductor.config.bucket_reduce_scatters_fx = "none" - torch._inductor.config.reorder_for_compute_comm_overlap = False else: torch._inductor.config.bucket_all_gathers_fx = ( job_config.experimental.bucket_all_gathers_fx From 8769396256c5ad8ba9e9ea15787a32b2bdb8bd91 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 11 Sep 2025 12:26:22 -0700 Subject: [PATCH 23/25] [dsv3] 1D AP w/ local_map --- torchtitan/components/optimizer.py | 43 ++++-- .../auto_parallel/parallelize_deepseekv3.py | 95 +++++++++++- torchtitan/models/moe.py | 137 ++++++++++-------- torchtitan/train.py | 2 +- 4 files changed, 202 insertions(+), 75 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index d3e9628103..1cadefbf71 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -351,12 +351,35 @@ def _update_expert_bias( dp_cp_mesh = ( parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None ) + + ################################################################3 + # AP friendly methods + + def is_moe_block(block): + moe_enabled = getattr(block, "moe_enabled", False) + has_moe_submod = hasattr(block, "moe") # AP + return moe_enabled or has_moe_submod + + def get_transformer_blocks(model_part): + if isinstance(model_part.layers, nn.ModuleDict): + # regular torchtitan + blocks = model_part.layers.values() + else: + # TODO: fix autoparallel to preserve the module dict + blocks = model_part.layers.children() + return blocks + + def should_manual_allreduce(tokens_per_expert_by_layer): + return not isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor) + ################################################################3 + # TODO: Currently this sync is blocking (thus exposed) and happens on the # default compute stream. Need to assess if this is OK performance-wise. 
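        # gather the tokens_per_expert buffer from every MoE layer across all model parts so a single all-reduce below can sync them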
tokens_per_expert_list = [] for model_part in model_parts: - for transformer_block in model_part.layers.values(): - if not transformer_block.moe_enabled: + blocks = get_transformer_blocks(model_part) + for transformer_block in blocks: + if not is_moe_block(transformer_block): continue if transformer_block.moe.load_balance_coeff is None: return @@ -372,17 +395,19 @@ def _update_expert_bias( tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list) if dp_cp_mesh is not None: - # Perform single all-reduce to get global statistics across all processes - pg = dp_cp_mesh.get_group() - torch.distributed.all_reduce( - tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM - ) + if should_manual_allreduce(tokens_per_expert_by_layer): + # Perform single all-reduce to get global statistics across all processes + pg = dp_cp_mesh.get_group() + torch.distributed.all_reduce( + tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM + ) moe_layer_idx = 0 with torch.no_grad(): for model_part in model_parts: - for transformer_block in model_part.layers.values(): - if not transformer_block.moe_enabled: + blocks = get_transformer_blocks(model_part) + for transformer_block in blocks: + if not is_moe_block(transformer_block): continue moe = transformer_block.moe diff --git a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py index c1991dfa5a..cf69511e0a 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py +++ b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py @@ -19,14 +19,59 @@ from torchtitan.tools.logging import logger +def apply_local_map_to_moe(): + """ + TODO: fix HOPs not restoring the original signature. + TODO: fix tracing with local shapes so that we can use Shard placements + + Current HOP signature we get: + + class subgraph_0(torch.nn.Module): + def forward(self, + rms_norm_5: "f32[64, 2048, 256][524288, 256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__router____modules__gate____parameters__weight: "f32[8, 256][256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____buffers__expert_bias: "f32[8][1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w1: "f32[8, 256, 256][65536, 256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w3: "f32[8, 256, 256][65536, 256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w2: "f32[8, 256, 256][65536, 256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w1____parameters__weight: "f32[512, 256][256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w3____parameters__weight: "f32[512, 256][256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w2____parameters__weight: "f32[256, 512][512, 1]cuda:0"): + """ + from torchtitan.models import moe + from torch.distributed._tensor.experimental import local_map + moe._moe_forward = local_map( + moe._moe_forward, + out_placements=( + (Replicate(),), # (Shard(0),), + (Replicate(),), + ), + in_placements=( + (Replicate(),), # (Shard(0),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + ), + 
redistribute_inputs=True, + in_grad_placements=None, + device_mesh=None, + ) + + +# Run workflow with: +# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel def parallelize_deepseekv3( model, parallel_dims: ParallelDims, job_config: JobConfig, ): """ - Apply tensor parallelism, activation checkpointing, torch.compile, and data - parallelism to the model. + Apply Autoparallel to the model NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. @@ -54,6 +99,9 @@ def input_fn(): assert parallel_dims.cp_enabled is False, "CP not supported yet" assert parallel_dims.pp_enabled is False, "PP not supported yet" + # apply local_map to MoE + apply_local_map_to_moe() + # torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( # lambda bucket_idx: 500 / parallel_dims.tp # ) @@ -131,4 +179,47 @@ def _return_as_dtensor_for_loss_parallel(module, args, output): # removing it at any point parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + _preserve_moe_attributes(model, parallel_mod) + return parallel_mod + + +def _preserve_moe_attributes(original_model, parallel_model): + """ + Preserve MoE custom attributes from the original model to the parallel model. + This is only needed for attributes that aren't used in the graph, so they aren't + lifted as graph inputs and fetched by the pre-graph runtime wrapper. + + `moe_enabled` and `load_balance_coeff` are used later in the optimizer to identify + this block as a moe block. This should be safe as they are read-only. + """ + def get_moe_modules(model): + """Extract all MoE modules from the model.""" + moe_modules = [] + if hasattr(model, 'layers'): + if isinstance(model.layers, torch.nn.ModuleDict): + # regular torchtitan structure + blocks = model.layers.values() + else: + # autoparallel might change structure + blocks = model.layers.children() if hasattr(model.layers, 'children') else [] + + for block in blocks: + if hasattr(block, 'moe_enabled') and block.moe_enabled and hasattr(block, 'moe'): + moe_modules.append(block.moe) + elif hasattr(block, 'moe'): # fallback for autoparallel + moe_modules.append(block.moe) + return moe_modules + + original_moe_modules = get_moe_modules(original_model) + parallel_moe_modules = get_moe_modules(parallel_model) + + # Copy custom attributes from original to parallel MoE modules + # This is fine to do since these attributes are read only + for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules): + if hasattr(orig_moe, 'moe_enabled'): + par_moe.moe_enabled = orig_moe.moe_enabled + + # Copy load_balance_coeff + if hasattr(orig_moe, 'load_balance_coeff'): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py index 8be14ecbf0..1ec8e3b23b 100644 --- a/torchtitan/models/moe.py +++ b/torchtitan/models/moe.py @@ -12,6 +12,7 @@ from torch import nn from torchtitan.distributed.expert_parallel import expert_parallel +from torch.distributed.tensor.placement_types import Shard, Replicate @dataclass @@ -310,6 +311,77 @@ def forward( num_tokens_per_expert, ) +def _moe_forward(x, router, expert_bias, reorderer, score_before_experts, experts, shared_experts): + # x: 64, 2048, 256 + bs, slen, dim = x.shape + x = x.view(-1, dim) + + # top_scores and selected_experts_indices shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, +
selected_experts_indices, + num_tokens_per_expert, + ) = router(x, expert_bias) + + # tokens_per_expert will be used to update the expert bias for load balancing. + # and also to count the expert usage + # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- + # first in the forward pass, and then in the backward pass. However, this has no + # effect on the expert bias update thanks to the torch.sign() operator. + # moved out to remove mutation + # with torch.no_grad(): + # tokens_per_expert.add_(num_tokens_per_expert) + + # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + # NOTE: the reason we need to compute num_tokens_per_expert again is: + # 1st computation in router is to update self.tokens_per_expert + # which would be the same across all TP ranks. + # 2nd computation in reorderer is for the actual routing and experts computation + # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. + # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. + ( + top_scores_experts_sorted, + token_indices_experts_sorted, + num_tokens_per_expert, + ) = reorderer(top_scores, selected_experts_indices) + + # shape (bs*slen*top_k, dim) + token_indices_experts_sorted = token_indices_experts_sorted.reshape( + -1, 1 + ).expand(-1, dim) + + # shape (bs*slen*top_k, dim) + routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) + + if score_before_experts: + routed_input = ( + routed_input.to(torch.float32) + * top_scores_experts_sorted.reshape(-1, 1) + ).to(x.dtype) + + # shape (bs*slen*top_k, dim) + routed_output = experts(routed_input, num_tokens_per_expert) + + if not score_before_experts: + routed_output = ( + routed_output.to(torch.float32) + * top_scores_experts_sorted.reshape(-1, 1) + ).to(x.dtype) + + # shared expert + if shared_experts is not None: + out = shared_experts(x) + else: + out = torch.zeros_like(x) + + out = out.scatter_add( + dim=0, index=token_indices_experts_sorted, src=routed_output + ) + out = out.reshape(bs, slen, dim) + return out, num_tokens_per_expert + class MoE(nn.Module): def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): @@ -367,72 +439,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. """ - bs, slen, dim = x.shape - x = x.view(-1, dim) - - # top_scores and selected_experts_indices shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) - ( - top_scores, - selected_experts_indices, - num_tokens_per_expert, - ) = self.router(x, self.expert_bias) + out, num_tokens_per_expert = _moe_forward(x, self.router, self.expert_bias, self.reorderer, self.score_before_experts, self.experts, self.shared_experts) - # tokens_per_expert will be used to update the expert bias for load balancing. - # and also to count the expert usage - # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- - # first in the forward pass, and then in the backward pass. However, this has no - # effect on the expert bias update thanks to the torch.sign() operator. 
+ # HOPs don't support buffer mutations, keep this outside with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) - - # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) - # NOTE: the reason we need to compute num_tokens_per_expert again is: - # 1st computation in router is to update self.tokens_per_expert - # which would be the same across all TP ranks. - # 2nd computation in reorderer is for the actual routing and experts computation - # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. - # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. - ( - top_scores_experts_sorted, - token_indices_experts_sorted, - num_tokens_per_expert, - ) = self.reorderer(top_scores, selected_experts_indices) - - # shape (bs*slen*top_k, dim) - token_indices_experts_sorted = token_indices_experts_sorted.reshape( - -1, 1 - ).expand(-1, dim) - - # shape (bs*slen*top_k, dim) - routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) - - if self.score_before_experts: - routed_input = ( - routed_input.to(torch.float32) - * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) - - # shape (bs*slen*top_k, dim) - routed_output = self.experts(routed_input, num_tokens_per_expert) - - if not self.score_before_experts: - routed_output = ( - routed_output.to(torch.float32) - * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) - - # shared expert - if self.shared_experts is not None: - out = self.shared_experts(x) - else: - out = torch.zeros_like(x) - - out = out.scatter_add( - dim=0, index=token_indices_experts_sorted, src=routed_output - ) - out = out.reshape(bs, slen, dim) return out def init_weights( diff --git a/torchtitan/train.py b/torchtitan/train.py index b69b74faac..3396fa56d2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -307,7 +307,7 @@ def __init__(self, job_config: JobConfig): # confirm that user will be able to view loss metrics on the console ensure_pp_loss_visible(parallel_dims, job_config, color) else: - # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel + # apply Autoparallel model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) model.to_empty(device=init_device) From db224791abca94707ead46efe4de4d776ec37e7d Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Wed, 17 Sep 2025 16:00:36 -0700 Subject: [PATCH 24/25] [dsv3] Turn off Flex for AP --- torchtitan/models/deepseek_v3/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 1c3d2b19d2..bf3232fd04 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -100,8 +100,8 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=True, - attn_mask_type="block_causal", + # use_flex_attn=True, + # attn_mask_type="block_causal", ), "236B": DeepSeekV3ModelArgs( vocab_size=102400, @@ -127,8 +127,8 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, - attn_mask_type="block_causal", + # use_flex_attn=True, + # attn_mask_type="block_causal", ), "671B": DeepSeekV3ModelArgs( vocab_size=129280, @@ -154,8 +154,8 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, - attn_mask_type="block_causal", + # use_flex_attn=True, + # attn_mask_type="block_causal", ), } From 45b15f6addfe6c903eb4b6811808133fb91cc43e Mon 
Sep 17 00:00:00 2001 From: IvanKobzarev Date: Fri, 12 Sep 2025 01:00:58 -0700 Subject: [PATCH 25/25] [autoparallel] Add experimental config to enable autoparallel_asynctp stack-info: PR: https://github.com/pytorch/torchtitan/pull/1772, branch: IvanKobzarev/stack/2 --- torchtitan/config/job_config.py | 5 ++ .../auto_parallel/parallelize_llama.py | 50 ++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 8571e5680c..8e10785bdf 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -739,6 +739,11 @@ class Experimental: enable_simplefsdp_passes: bool = False + enable_inductor_aten_fx_overlap_scheduler: bool = False + enable_inductor_aten_fx_overlap_scheduler_bucketing: bool = False + enable_autoparallel_asynctp: bool = False + + @dataclass class Validation: enable: bool = False diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index 6648f29ab8..9d53d9a755 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -10,7 +10,6 @@ from autoparallel.api import AutoParallel -from torch.distributed import DeviceMesh from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor.placement_types import Replicate, Shard @@ -33,6 +32,7 @@ def parallelize_llama( the model must fit on GPU or CPU memory. """ world_mesh = parallel_dims.world_mesh + def input_fn(): global_batch_size = job_config.training.global_batch_size if global_batch_size < 0: @@ -62,6 +62,51 @@ def input_fn(): lambda bucket_idx: 1000 / parallel_dims.tp ) + enable_overlap_scheduling = ( + job_config.experimental.enable_inductor_aten_fx_overlap_scheduler + ) + enable_overlap_scheduling_bucketing = ( + job_config.experimental.enable_inductor_aten_fx_overlap_scheduler_bucketing + ) + if enable_overlap_scheduling_bucketing: + assert ( + enable_overlap_scheduling + ), "bucketing can not be used without overlap scheduling" + + if enable_overlap_scheduling: + from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler + + torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = ( + enable_overlap_scheduling_bucketing + ) + + def _overlap_bucketing_pass(graph): + overlap_scheduler = OverlapScheduler(graph.owning_module) + overlap_scheduler.run() + + torch._inductor.config.post_grad_custom_post_pass = _overlap_bucketing_pass + + if job_config.experimental.enable_autoparallel_asynctp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + assert "tp" in world_mesh.mesh_dim_names + enable_symm_mem_for_group(world_mesh["tp"].get_group().group_name) + torch._inductor.config._micro_pipeline_tp = False + # Disable inductor AsyncTP passes, in favor of using Autoparallel passes fork. 
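+            # chain the forked pass after any post-grad pass installed above (e.g. the overlap-scheduling pass) so both still run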
+ from autoparallel.asynctp import micro_pipeline_tp_pass + + existing_post_grad_custom_post_pass = ( + torch._inductor.config.post_grad_custom_post_pass + ) + + def _pass(graph): + if existing_post_grad_custom_post_pass is not None: + existing_post_grad_custom_post_pass(graph) + + micro_pipeline_tp_pass(graph, None) + + torch._inductor.config.post_grad_custom_post_pass = _pass + # bail out # model = model_fn() # return model @@ -101,7 +146,8 @@ def input_fn(): ) out_sharding = x_sharding loss_parallel_enabled = ( - parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel ) if loss_parallel_enabled: out_sharding = tuple(