From 3ccd12c87667e93dbacd72c0da24060c85421c96 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 12 Jun 2025 21:08:11 -0700 Subject: [PATCH 01/49] [WIP] Integrate autoparallel into torchtitan TODO - try converting model params into fake tensors - figure out init fn - integrate torchtitan configs for DP/TP to control autop Hack an init_fn for llama3 and observe loss decreasing with autoparallel """ [rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step: 1 loss: 8.1880 memory: 4.88GiB(6.16%) tps: 28 [rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step: 2 loss: 8.1610 memory: 4.90GiB(6.20%) tps: 13,785 [rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step: 3 loss: 8.0871 memory: 4.90GiB(6.20%) tps: 14,006 [rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step: 4 loss: 7.9516 memory: 4.90GiB(6.20%) tps: 13,770 [rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step: 5 loss: 7.8552 memory: 4.90GiB(6.20%) tps: 13,959 [rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step: 6 loss: 7.7732 memory: 4.90GiB(6.20%) tps: 13,859 [rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step: 7 loss: 7.6987 memory: 4.90GiB(6.20%) tps: 13,664 [rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step: 8 loss: 7.6779 memory: 4.90GiB(6.20%) tps: 13,985 [rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step: 9 loss: 7.6043 memory: 4.90GiB(6.20%) tps: 13,962 [rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10 loss: 7.5778 memory: 4.90GiB(6.20%) tps: 13,891 """ Adopt new autoparallel API with meta-init model Allows reverting a lot of the hacks in the original integration that were caused by not creating a model obj in the train.py due to passing a model_fn builder to autop. Fixes to align with latest autoparallel Add inductor config knobs for comms optimizations to torchtitan Make inductor always run compile passes basically, this is an annoying workaround for debugging iteratively. 1- you run the model, it compiles, but something weird happens 2- you enable some logging or tlparse, rerun. but inductor decides not to run your pass anymore, its results are cached. since (2) has confused me horribly on more than one occasion, i just disable caching for now Drop hacky llama3_init_fn and use autop init_weights feature Relying on https://github.com/pytorch-labs/autoparallel/pull/20, this lets us automatically apply a user's init_weights fn to the autoparallel model. Verified this works with `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4` ``` [rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step: 1 loss: 8.1848 memory: 1.09GiB(1.14%) tps: 77 tflops: 0.01 mfu: 0.00% [rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step: 2 loss: 8.1619 memory: 1.15GiB(1.21%) tps: 48,138 tflops: 3.46 mfu: 0.35 % [rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step: 3 loss: 8.1140 memory: 1.15GiB(1.21%) tps: 88,440 tflops: 6.36 mfu: 0.64 % [rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step: 4 loss: 8.0099 memory: 1.15GiB(1.21%) tps: 82,626 tflops: 5.94 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step: 5 loss: 7.8928 memory: 1.15GiB(1.21%) tps: 81,594 tflops: 5.87 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step: 6 loss: 7.7758 memory: 1.15GiB(1.21%) tps: 79,607 tflops: 5.72 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step: 7 loss: 7.6221 memory: 1.15GiB(1.21%) tps: 81,448 tflops: 5.86 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step: 8 loss: 7.5578 memory: 1.15GiB(1.21%) tps: 79,732 tflops: 5.73 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step: 9 loss: 7.3851 memory: 1.15GiB(1.21%) tps: 85,655 tflops: 6.16 mfu: 0.62 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10 loss: 7.3361 memory: 1.15GiB(1.21%) tps: 81,855 tflops: 5.89 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete ``` fix lint --- torchtitan/config_manager.py | 22 ++++++ torchtitan/experiments/__init__.py | 1 + .../experiments/auto_parallel/README.md | 7 ++ .../experiments/auto_parallel/__init__.py | 31 ++++++++ .../auto_parallel/parallelize_llama.py | 77 +++++++++++++++++++ torchtitan/train.py | 27 +++++-- 6 files changed, 159 insertions(+), 6 deletions(-) create mode 100644 torchtitan/experiments/auto_parallel/README.md create mode 100644 torchtitan/experiments/auto_parallel/__init__.py create mode 100644 torchtitan/experiments/auto_parallel/parallelize_llama.py diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 5f1a1e8b7f..1a45a7800c 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -664,6 +664,28 @@ class Experimental: needs to ensure that the path can be imported. """ + reorder_for_compute_comm_overlap: bool = False + """ + Whether to enable inductor comm reordering passes + """ + + reorder_for_compute_comm_overlap_passes: list[str] = field( + default_factory=lambda: [ + "sink_waits", + "reorder_communication_preserving_peak_memory", + ] + ) + """ + Sequence of reordering passes (names of functions inside _inductor.comms) to call, + if reorder_for_compute_comm_overlap is enabled. + """ + + reorder_prefetch_limit: int | None = None + """ + How many ops to allow moving any individual collective, if 'reorder_communication_preserving_peak_memory' + pass is enabled. default of None means unlimited + """ + @dataclass class Validation: diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 4c54bdc13e..b7ff983e97 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -4,5 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torchtitan.experiments.auto_parallel # noqa: F401 import torchtitan.experiments.llama4 # noqa: F401 import torchtitan.experiments.simple_fsdp # noqa: F401 diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/auto_parallel/README.md new file mode 100644 index 0000000000..ef66a59166 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/README.md @@ -0,0 +1,7 @@ +## Auto Parallel + +requires installing git@github.com:pytorch-labs/autoparallel.git + +`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4` + +(or llama3-8b.toml) diff --git a/torchtitan/experiments/auto_parallel/__init__.py b/torchtitan/experiments/auto_parallel/__init__.py new file mode 100644 index 0000000000..8f5a876b4e --- /dev/null +++ b/torchtitan/experiments/auto_parallel/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec +from .parallelize_llama import parallelize_llama + +register_train_spec( + TrainSpec( + name="llama3_auto_parallel", + cls=Transformer, + config=llama3_configs, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) +) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py new file mode 100644 index 0000000000..bb7f1204df --- /dev/null +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import torch + +from autoparallel.api import AutoParallel + +from torch.distributed import DeviceMesh +from torch.distributed.tensor.placement_types import Replicate, Shard + +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims + +from torchtitan.tools.logging import logger + + +def parallelize_llama( + model, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + dp_degree = world_mesh["dp"].size() + global_batch_size = job_config.training.local_batch_size * dp_degree + return torch.rand( + (global_batch_size, job_config.training.seq_len), device="cuda" + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert ( + len(world_mesh.shape) == 2 + ), "Only support 2D mesh (DP, TP) for now- OK if one has size=1" + assert parallel_dims.dp_shard_enabled is True, "DDP not supported yet" + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + # bail out + # model = model_fn() + # return model + + autop = AutoParallel(model, input_fn, world_mesh) + autop.add_parameter_memory_constraint(low=None, high=None) + + x_sharding = (Shard(0), Replicate()) + + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([x_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + if job_config.training.compile: + torch._inductor.config.reorder_for_peak_memory = False + parallel_mod = torch.compile(parallel_mod, fullgraph=True) + + return parallel_mod diff --git a/torchtitan/train.py b/torchtitan/train.py index 3dc8a61b28..9626d8a5a4 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -12,6 +12,7 @@ import torch from torch.distributed.elastic.multiprocessing.errors import record +from torch.distributed.tensor import DTensor import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module @@ -116,6 +117,21 @@ def __init__(self, job_config: JobConfig): gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug ) + # TODO(whc) + # I do this becuase otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + + # allow configuring inductor comms optimizations from torchtitan commandline + torch._inductor.config.reorder_for_compute_comm_overlap = ( + job_config.experimental.reorder_for_compute_comm_overlap + ) + torch._inductor.config.reorder_for_compute_comm_overlap_passes = ( + job_config.experimental.reorder_for_compute_comm_overlap_passes + ) + torch._inductor.config.reorder_prefetch_limit = ( + job_config.experimental.reorder_prefetch_limit + ) + # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( @@ -141,20 +157,19 @@ def __init__(self, job_config: JobConfig): ) # build model (using meta init) - model_cls = self.train_spec.cls model_args = self.train_spec.config[job_config.model.flavor] + model_cls = self.train_spec.cls # set the model args from training job configs model_args.update_from_config(job_config, tokenizer) - logger.info( f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) + with torch.device("meta"): model = model_cls(model_args) - - # Build the collection of model converters. No-op if `model.converters` empty - model_converters = build_model_converters(job_config, parallel_dims) - model_converters.convert(model) + # Build the collection of model converters. No-op if `model.converters` empty + model_converters = build_model_converters(job_config, parallel_dims) + model_converters.convert(model) # metrics logging build_metrics_processor_fn = ( From e6d2cafc4b33868a49af2e9c5bd6563994f4fa64 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 26 Jun 2025 18:34:34 -0700 Subject: [PATCH 02/49] Autoparallel support for DP-only, DP+TP, or TP-only lets existing torchtitan knobs which govern DP/TP mesh creation and mesh size influence the sharding constraints of autoparallel, allowing it to support these different sharding configurations. --- .../auto_parallel/parallelize_llama.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index bb7f1204df..6e0d9b4dcb 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -38,17 +38,13 @@ def input_fn(): if global_batch_size < 0: # This global batch size results in 1 gradient accumulation # step. - dp_degree = world_mesh["dp"].size() + dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard global_batch_size = job_config.training.local_batch_size * dp_degree return torch.rand( (global_batch_size, job_config.training.seq_len), device="cuda" ) # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP - assert ( - len(world_mesh.shape) == 2 - ), "Only support 2D mesh (DP, TP) for now- OK if one has size=1" - assert parallel_dims.dp_shard_enabled is True, "DDP not supported yet" assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" assert parallel_dims.cp_enabled is False, "CP not supported yet" assert parallel_dims.pp_enabled is False, "PP not supported yet" @@ -60,8 +56,18 @@ def input_fn(): autop = AutoParallel(model, input_fn, world_mesh) autop.add_parameter_memory_constraint(low=None, high=None) - x_sharding = (Shard(0), Replicate()) - + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "dp_shard": Shard(0), + "tp": Replicate(), + } + assert all( + name in possible_input_shardings for name in world_mesh.mesh_dim_names + ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + x_sharding = tuple( + possible_input_shardings[name] for name in world_mesh.mesh_dim_names + ) autop.add_input_constraints([x_sharding]) autop.add_output_constraints([x_sharding]) t0 = time.time() From 68476b3fea19dda14f760ea246aac85194c37b39 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 25 Jul 2025 09:25:24 -0700 Subject: [PATCH 03/49] Update CLI inductor configs for bucketing/reordering --- torchtitan/config_manager.py | 8 +++++++- torchtitan/train.py | 8 ++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 1a45a7800c..44567a0fdb 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -664,6 +664,12 @@ class Experimental: needs to ensure that the path can be imported. """ + # "none", "all", "only_fsdp" + bucket_all_gathers_fx: str | None = None + + # "none", "all" + bucket_reduce_scatters_fx: str | None = None + reorder_for_compute_comm_overlap: bool = False """ Whether to enable inductor comm reordering passes @@ -671,7 +677,7 @@ class Experimental: reorder_for_compute_comm_overlap_passes: list[str] = field( default_factory=lambda: [ - "sink_waits", + "sink_waits_iterative", "reorder_communication_preserving_peak_memory", ] ) diff --git a/torchtitan/train.py b/torchtitan/train.py index 9626d8a5a4..8ca90037b0 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -120,8 +120,16 @@ def __init__(self, job_config: JobConfig): # TODO(whc) # I do this becuase otherwise sometimes inductor will skip re-running passes like comms reordering torch._inductor.config.force_disable_caches = True + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. + torch._inductor.config.allow_buffer_reuse = False # allow configuring inductor comms optimizations from torchtitan commandline + torch._inductor.config.bucket_all_gathers_fx = ( + job_config.experimental.bucket_all_gathers_fx + ) + torch._inductor.config.bucket_reduce_scatters_fx = ( + job_config.experimental.bucket_reduce_scatters_fx + ) torch._inductor.config.reorder_for_compute_comm_overlap = ( job_config.experimental.reorder_for_compute_comm_overlap ) From 9ee9f75bfd4a009da580922d6cd07cf276db85c9 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 25 Jul 2025 12:07:17 -0700 Subject: [PATCH 04/49] add back llama3_autoparallel_init_fn --- torchtitan/train.py | 75 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 8ca90037b0..35435eecab 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -238,6 +238,78 @@ def __init__(self, job_config: JobConfig): self.loss_fn, self.gradient_accumulation_steps ) + def llama3_autoparallel_init_fn(model): + # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying + # code from the llama3 init_weights functions throughout the model components, and adjusting them to use + # the new FQN structures in autoparallel. + # TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module + def param(name): + return model.get_parameter(f"params.{name}") + + from torchtitan.models.llama3.model.model import precompute_freqs_cis + + model.buffers_.get_buffer("freqs_cis").copy_( + DTensor.from_local( + precompute_freqs_cis( + model_args.dim // model_args.n_heads, + model_args.max_seq_len, + model_args.rope_theta, + ), + device_mesh=model.buffers_.get_buffer("freqs_cis").device_mesh, + ) + ) + + torch.nn.init.normal_(param("tok_embeddings/weight")) + + def init_layer(i): + for norm in ("attention_norm", "ffn_norm"): + torch.nn.init.ones_(param(f"layers/{i}/{norm}/weight")) + + if model_args.depth_init: + weight_init_std = 0.02 / (2 * (i + 1)) ** 0.5 + else: + weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 + + for linear in ("wq", "wk", "wv"): + torch.nn.init.trunc_normal_( + param(f"layers/{i}/attention/{linear}/weight"), + mean=0.0, + std=0.02, + ) + torch.nn.init.trunc_normal_( + param(f"layers/{i}/attention/wo/weight"), + mean=0.0, + std=weight_init_std, + ) + + torch.nn.init.trunc_normal_( + param(f"layers/{i}/feed_forward/w1/weight"), mean=0.0, std=0.02 + ) + for linear in ("w2", "w3"): + torch.nn.init.trunc_normal_( + param(f"layers/{i}/feed_forward/{linear}/weight"), + mean=0.0, + std=weight_init_std, + ) + + for i in range(model_args.n_layers): + init_layer(i) + + if param("norm/weight") is not None: + torch.nn.init.ones_(param("norm/weight")) + + final_out_std = model_args.dim**-0.5 + cutoff_factor = 3 + + if param("output/weight") is not None: + torch.nn.init.trunc_normal_( + param("output/weight"), + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + # apply parallelisms and initialization if parallel_dims.pp_enabled: if not self.train_spec.pipelining_fn: @@ -282,7 +354,8 @@ def __init__(self, job_config: JobConfig): model.to_empty(device=init_device) with torch.no_grad(): - model.init_weights(buffer_device=buffer_device) + # model.init_weights(buffer_device=buffer_device) + llama3_autoparallel_init_fn(model) model.train() self.model_parts = [model] From f6e4099cfa36f00a2b654fef2483ebc55d1a738c Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 27 Jul 2025 19:53:10 -0700 Subject: [PATCH 05/49] Track API change from new AOTAutograd interface Signed-off-by: Edward Z. Yang --- .../auto_parallel/parallelize_llama.py | 44 +++++++++---------- torchtitan/train.py | 6 +-- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index 6e0d9b4dcb..4b0e7a3e03 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -53,28 +53,28 @@ def input_fn(): # model = model_fn() # return model - autop = AutoParallel(model, input_fn, world_mesh) - autop.add_parameter_memory_constraint(low=None, high=None) - - possible_input_shardings = { - # maps relative to mesh dim names used in torchtitan - "dp_replicate": Shard(0), - "dp_shard": Shard(0), - "tp": Replicate(), - } - assert all( - name in possible_input_shardings for name in world_mesh.mesh_dim_names - ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" - x_sharding = tuple( - possible_input_shardings[name] for name in world_mesh.mesh_dim_names - ) - autop.add_input_constraints([x_sharding]) - autop.add_output_constraints([x_sharding]) - t0 = time.time() - sharding_placement = autop.optimize_placement() - t1 = time.time() - logger.info(f"AutoParallel took {t1 - t0} seconds") - parallel_mod = autop.apply_placement(sharding_placement) + with AutoParallel(model, input_fn, world_mesh) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "dp_shard": Shard(0), + "tp": Replicate(), + } + assert all( + name in possible_input_shardings for name in world_mesh.mesh_dim_names + ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + x_sharding = tuple( + possible_input_shardings[name] for name in world_mesh.mesh_dim_names + ) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([x_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) if job_config.training.compile: torch._inductor.config.reorder_for_peak_memory = False diff --git a/torchtitan/train.py b/torchtitan/train.py index 35435eecab..cd85e27045 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -244,18 +244,18 @@ def llama3_autoparallel_init_fn(model): # the new FQN structures in autoparallel. # TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module def param(name): - return model.get_parameter(f"params.{name}") + return model.get_parameter(f"{name.replace('/', '.')}") from torchtitan.models.llama3.model.model import precompute_freqs_cis - model.buffers_.get_buffer("freqs_cis").copy_( + model.get_buffer("freqs_cis").copy_( DTensor.from_local( precompute_freqs_cis( model_args.dim // model_args.n_heads, model_args.max_seq_len, model_args.rope_theta, ), - device_mesh=model.buffers_.get_buffer("freqs_cis").device_mesh, + device_mesh=model.get_buffer("freqs_cis").device_mesh, ) ) From 4d7ee8a0b864525e8194876d4d91ea2f6e9be743 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 28 Jul 2025 14:30:49 -0700 Subject: [PATCH 06/49] Support forcing the model into bf16 for perf debugging --- torchtitan/config_manager.py | 1 + torchtitan/experiments/auto_parallel/parallelize_llama.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 44567a0fdb..ee7de65105 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -692,6 +692,7 @@ class Experimental: pass is enabled. default of None means unlimited """ + autop_force_bf16: bool = False @dataclass class Validation: diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index 4b0e7a3e03..a8f7b217cd 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -52,6 +52,9 @@ def input_fn(): # bail out # model = model_fn() # return model + if job_config.experimental.autop_force_bf16: + logger.info("Forcing bf16 on model") + model = model.bfloat16() with AutoParallel(model, input_fn, world_mesh) as autop: autop.add_parameter_memory_constraint(low=None, high=None) From b801d0b5b307617e25e9563694397673cb3041ee Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 29 Jul 2025 10:56:50 -0700 Subject: [PATCH 07/49] Integrate MixedPrecision with AutoParallel and fix example_inputs --- .../auto_parallel/parallelize_llama.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index a8f7b217cd..bb03e81c95 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -11,9 +11,10 @@ from autoparallel.api import AutoParallel from torch.distributed import DeviceMesh +from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor.placement_types import Replicate, Shard -from torchtitan.config_manager import JobConfig +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.tools.logging import logger @@ -40,9 +41,13 @@ def input_fn(): # step. dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard global_batch_size = job_config.training.local_batch_size * dp_degree - return torch.rand( - (global_batch_size, job_config.training.seq_len), device="cuda" - ) + return torch.randint( + 0, + # job_config.training.vocab_size, + model.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" @@ -56,7 +61,10 @@ def input_fn(): logger.info("Forcing bf16 on model") model = model.bfloat16() - with AutoParallel(model, input_fn, world_mesh) as autop: + param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] + reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + with AutoParallel(model, input_fn, world_mesh, mp_policy=mp_policy) as autop: autop.add_parameter_memory_constraint(low=None, high=None) possible_input_shardings = { From b099cf9c06b11084b01645a7878ff8ee91d82ef3 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 29 Jul 2025 13:34:16 -0700 Subject: [PATCH 08/49] Use in-place compile API Signed-off-by: Edward Z. Yang --- torchtitan/experiments/auto_parallel/parallelize_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index bb03e81c95..39ca4ab1da 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -89,6 +89,6 @@ def input_fn(): if job_config.training.compile: torch._inductor.config.reorder_for_peak_memory = False - parallel_mod = torch.compile(parallel_mod, fullgraph=True) + parallel_mod.compile(fullgraph=True) return parallel_mod From b3587d9fec06fee75fd5398673831a95a08b083c Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 29 Jul 2025 13:14:05 -0700 Subject: [PATCH 09/49] Fix bucketing pass configs - fix passing of "none" (not None) to control bucketing passes --- torchtitan/config_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index ee7de65105..131d7d2bd8 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -665,10 +665,10 @@ class Experimental: """ # "none", "all", "only_fsdp" - bucket_all_gathers_fx: str | None = None + bucket_all_gathers_fx: str = "none" # "none", "all" - bucket_reduce_scatters_fx: str | None = None + bucket_reduce_scatters_fx: str = "none" reorder_for_compute_comm_overlap: bool = False """ From 42c2c07613d0239c5b630bcfda534db2ffd66ab1 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 30 Jul 2025 13:06:14 -0700 Subject: [PATCH 10/49] Support both eager and autoparallel init based on model.name --- torchtitan/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index cd85e27045..793957a761 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -354,8 +354,10 @@ def init_layer(i): model.to_empty(device=init_device) with torch.no_grad(): - # model.init_weights(buffer_device=buffer_device) - llama3_autoparallel_init_fn(model) + if job_config.model.name == "llama3_auto_parallel": + llama3_autoparallel_init_fn(model) + else: + model.init_weights(buffer_device=buffer_device) model.train() self.model_parts = [model] From d93845e57e88c2fa9dc7eeda6e754129eb01660f Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 6 Aug 2025 11:57:11 -0700 Subject: [PATCH 11/49] Remove llama3 init weights hack since https://github.com/meta-pytorch/autoparallel/pull/52 landed --- torchtitan/train.py | 77 +-------------------------------------------- 1 file changed, 1 insertion(+), 76 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 793957a761..8ca90037b0 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -238,78 +238,6 @@ def __init__(self, job_config: JobConfig): self.loss_fn, self.gradient_accumulation_steps ) - def llama3_autoparallel_init_fn(model): - # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying - # code from the llama3 init_weights functions throughout the model components, and adjusting them to use - # the new FQN structures in autoparallel. - # TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module - def param(name): - return model.get_parameter(f"{name.replace('/', '.')}") - - from torchtitan.models.llama3.model.model import precompute_freqs_cis - - model.get_buffer("freqs_cis").copy_( - DTensor.from_local( - precompute_freqs_cis( - model_args.dim // model_args.n_heads, - model_args.max_seq_len, - model_args.rope_theta, - ), - device_mesh=model.get_buffer("freqs_cis").device_mesh, - ) - ) - - torch.nn.init.normal_(param("tok_embeddings/weight")) - - def init_layer(i): - for norm in ("attention_norm", "ffn_norm"): - torch.nn.init.ones_(param(f"layers/{i}/{norm}/weight")) - - if model_args.depth_init: - weight_init_std = 0.02 / (2 * (i + 1)) ** 0.5 - else: - weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 - - for linear in ("wq", "wk", "wv"): - torch.nn.init.trunc_normal_( - param(f"layers/{i}/attention/{linear}/weight"), - mean=0.0, - std=0.02, - ) - torch.nn.init.trunc_normal_( - param(f"layers/{i}/attention/wo/weight"), - mean=0.0, - std=weight_init_std, - ) - - torch.nn.init.trunc_normal_( - param(f"layers/{i}/feed_forward/w1/weight"), mean=0.0, std=0.02 - ) - for linear in ("w2", "w3"): - torch.nn.init.trunc_normal_( - param(f"layers/{i}/feed_forward/{linear}/weight"), - mean=0.0, - std=weight_init_std, - ) - - for i in range(model_args.n_layers): - init_layer(i) - - if param("norm/weight") is not None: - torch.nn.init.ones_(param("norm/weight")) - - final_out_std = model_args.dim**-0.5 - cutoff_factor = 3 - - if param("output/weight") is not None: - torch.nn.init.trunc_normal_( - param("output/weight"), - mean=0.0, - std=final_out_std, - a=-cutoff_factor * final_out_std, - b=cutoff_factor * final_out_std, - ) - # apply parallelisms and initialization if parallel_dims.pp_enabled: if not self.train_spec.pipelining_fn: @@ -354,10 +282,7 @@ def init_layer(i): model.to_empty(device=init_device) with torch.no_grad(): - if job_config.model.name == "llama3_auto_parallel": - llama3_autoparallel_init_fn(model) - else: - model.init_weights(buffer_device=buffer_device) + model.init_weights(buffer_device=buffer_device) model.train() self.model_parts = [model] From 60f5f118f4c17747bb117d4a8018eec180e0bd89 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 7 Aug 2025 15:25:02 -0700 Subject: [PATCH 12/49] Print profiling manifold url prints an (internal, vpn) only link for each profile trace file that's saved to manifold. Just search for 'trace' in your job logs on mast, and click one of the rank links. e.g. [trainer37|5]:[titan] 2025-08-07 14:21:01,227 - root - INFO - Finished dumping profiler traces in 5.22 seconds: https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/torchtrain_datasets/tree/outputs/torchtitan-64-whc-jv2j4mp/profile_trace/iteration_20/rank37_trace.json --- torchtitan/tools/profiling.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 050b992cc8..7c822c3b66 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -19,7 +19,9 @@ # how much memory allocation/free ops to record in memory snapshots MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 - +PERFETTO_UI_ROOT_URL = ( + "https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html" +) @contextlib.contextmanager def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0): @@ -42,10 +44,22 @@ def trace_handler(prof): logger.info(f"Dumping profiler traces at step {prof.step_num}") begin = time.monotonic() - prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json") - logger.info( - f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" - ) + filename = f"{curr_trace_dir}/rank{rank}_trace.json" + + prof.export_chrome_trace(filename) + log_str = f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" + # not directly landable on upstream titan, + # but conveniently prints the internal url for perfetto on manifold for mast jobs + manifold_mount_prefix = "/mnt/mffuse/" + if filename.find(manifold_mount_prefix) == 0: + manifold_path = os.path.join("torchtrain_datasets/tree", filename.split(manifold_mount_prefix)[1]) + perfetto_url = ( + PERFETTO_UI_ROOT_URL + + "#!/?url=https://interncache-all.fbcdn.net/manifold/" + + manifold_path + ) + log_str += f": {perfetto_url}" + logger.info(log_str) logger.info(f"Profiling active. Traces will be saved at {trace_dir}") From 6c782eba53b51ee2f95af0be98f49c443ff9b52f Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 7 Aug 2025 20:07:39 -0700 Subject: [PATCH 13/49] Support new compile API from autoparallel PR #77 --- torchtitan/experiments/auto_parallel/parallelize_llama.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index 39ca4ab1da..2d3a3e2e2c 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -64,7 +64,7 @@ def input_fn(): param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) - with AutoParallel(model, input_fn, world_mesh, mp_policy=mp_policy) as autop: + with AutoParallel(model, input_fn, world_mesh, mp_policy=mp_policy, compile=job_config.training.compile) as autop: autop.add_parameter_memory_constraint(low=None, high=None) possible_input_shardings = { @@ -87,8 +87,4 @@ def input_fn(): logger.info(f"AutoParallel took {t1 - t0} seconds") parallel_mod = autop.apply_placement(sharding_placement) - if job_config.training.compile: - torch._inductor.config.reorder_for_peak_memory = False - parallel_mod.compile(fullgraph=True) - return parallel_mod From 4712163eb9d0ac6709711de8aa708abd7f8d38d3 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 8 Aug 2025 15:29:37 +0200 Subject: [PATCH 14/49] Fix bucket sizes for AutoParallel 1D (#1545) This PR makes bucket sizes for all-gather and reduce-scatter to be of the same size for 1d FSDP. --- .../auto_parallel/parallelize_llama.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index 2d3a3e2e2c..d003cb2e5f 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -41,19 +41,28 @@ def input_fn(): # step. dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard global_batch_size = job_config.training.local_batch_size * dp_degree - return torch.randint( - 0, - # job_config.training.vocab_size, - model.vocab_size, - (global_batch_size, job_config.training.seq_len), - device=torch.device("cuda"), - ), + return ( + torch.randint( + 0, + # job_config.training.vocab_size, + model.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" assert parallel_dims.cp_enabled is False, "CP not supported yet" assert parallel_dims.pp_enabled is False, "PP not supported yet" + torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( + lambda bucket_idx: 500 / parallel_dims.tp + ) + torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = ( + lambda bucket_idx: 1000 / parallel_dims.tp + ) + # bail out # model = model_fn() # return model @@ -64,7 +73,13 @@ def input_fn(): param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) - with AutoParallel(model, input_fn, world_mesh, mp_policy=mp_policy, compile=job_config.training.compile) as autop: + with AutoParallel( + model, + input_fn, + world_mesh, + mp_policy=mp_policy, + compile=job_config.training.compile, + ) as autop: autop.add_parameter_memory_constraint(low=None, high=None) possible_input_shardings = { From 3f04d223f7f3bd98510c8e288ff72183ff6caa5b Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sun, 10 Aug 2025 19:12:30 +0200 Subject: [PATCH 15/49] Add support for loss parallel (#1546) IMO we should just add the loss in the model and let autoparallel parallelize it for us. But for now, let's follow how the other models are implemented --- .../auto_parallel/parallelize_llama.py | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py index d003cb2e5f..49a8bc49ff 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py @@ -88,18 +88,48 @@ def input_fn(): "dp_shard": Shard(0), "tp": Replicate(), } + # only used if loss parallel is enabled + possible_output_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_shard": Shard(0), + "tp": Shard(2), + } assert all( name in possible_input_shardings for name in world_mesh.mesh_dim_names ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" x_sharding = tuple( possible_input_shardings[name] for name in world_mesh.mesh_dim_names ) + out_sharding = x_sharding + if parallel_dims.loss_parallel_enabled: + out_sharding = tuple( + possible_output_shardings[name] + for name in world_mesh.mesh_dim_names + if name != "dp_replicate" + ) autop.add_input_constraints([x_sharding]) - autop.add_output_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) t0 = time.time() sharding_placement = autop.optimize_placement() t1 = time.time() logger.info(f"AutoParallel took {t1 - t0} seconds") parallel_mod = autop.apply_placement(sharding_placement) + if parallel_dims.loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. + # We should just use AutoParallel to do this for us, but + # it would require putting the loss inside the model as well + def _return_as_dtensor_for_loss_parallel(module, args, output): + return torch.distributed.tensor.DTensor.from_local( + output, world_mesh["tp"], (Shard(2),) + ) + + # not keeping a reference to the hook, don't plan on + # removing it at any point + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + return parallel_mod From 8e50870985239adb109733e599f389350706b03b Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 18 Aug 2025 15:51:30 -0700 Subject: [PATCH 16/49] Add config for running simple-fsdp bucketing/reordering passes just add `--experimental.enable_simplefsdp_passes` and do not try to combine it with other `bucket_*` or `reorder_*` options. --- torchtitan/config_manager.py | 2 ++ torchtitan/train.py | 58 ++++++++++++++++++++++++++---------- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 131d7d2bd8..efc9d7a951 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -694,6 +694,8 @@ class Experimental: autop_force_bf16: bool = False + enable_simplefsdp_passes: bool = False + @dataclass class Validation: enabled: bool = False diff --git a/torchtitan/train.py b/torchtitan/train.py index 8ca90037b0..1ebdd984ed 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -124,21 +124,49 @@ def __init__(self, job_config: JobConfig): torch._inductor.config.allow_buffer_reuse = False # allow configuring inductor comms optimizations from torchtitan commandline - torch._inductor.config.bucket_all_gathers_fx = ( - job_config.experimental.bucket_all_gathers_fx - ) - torch._inductor.config.bucket_reduce_scatters_fx = ( - job_config.experimental.bucket_reduce_scatters_fx - ) - torch._inductor.config.reorder_for_compute_comm_overlap = ( - job_config.experimental.reorder_for_compute_comm_overlap - ) - torch._inductor.config.reorder_for_compute_comm_overlap_passes = ( - job_config.experimental.reorder_for_compute_comm_overlap_passes - ) - torch._inductor.config.reorder_prefetch_limit = ( - job_config.experimental.reorder_prefetch_limit - ) + if job_config.experimental.enable_simplefsdp_passes: + try: + from torch._inductor.simple_fsdp.bucket import bucket_fsdp_all_gather_concat_on_scheduler_ir + except ImportError: + print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c") + raise + + # Configs from Ruisi + + # set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives + torch._inductor.config.simplefsdp.relax_ratio = 0 + torch._inductor.config.allow_buffer_reuse = False + torch._inductor.config.simplefsdp.estimate_ir = False + torch._inductor.config.simplefsdp.estimate_verbose = False + torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl" + # set to True after the first communication estimation results are saved. This would reduce decision making time. + torch._inductor.config.simplefsdp.load_cache = False + torch._inductor.config.simplefsdp.enable_bucket_ir = True + torch._inductor.config.simplefsdp.enable_reorder_ir = True + torch._inductor.config.simplefsdp.simplefsdp_only = True # False for 2d True for 1d + torch._inductor.config.simplefsdp.peak_memory_offset = 0 + torch._inductor.config.simplefsdp.bucketing_type = "auto" + + # Don't use both sets of passes at the same time! + torch._inductor.config.bucket_all_gathers_fx = "none" + torch._inductor.config.bucket_reduce_scatters_fx = "none" + torch._inductor.config.reorder_for_compute_comm_overlap = False + else: + torch._inductor.config.bucket_all_gathers_fx = ( + job_config.experimental.bucket_all_gathers_fx + ) + torch._inductor.config.bucket_reduce_scatters_fx = ( + job_config.experimental.bucket_reduce_scatters_fx + ) + torch._inductor.config.reorder_for_compute_comm_overlap = ( + job_config.experimental.reorder_for_compute_comm_overlap + ) + torch._inductor.config.reorder_for_compute_comm_overlap_passes = ( + job_config.experimental.reorder_for_compute_comm_overlap_passes + ) + torch._inductor.config.reorder_prefetch_limit = ( + job_config.experimental.reorder_prefetch_limit + ) # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). From 91c56397e8f7c98ec2c685bdf551637f8780b56c Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 19 Aug 2025 13:46:38 -0700 Subject: [PATCH 17/49] Hook up deepseekv3_auto_parallel This command should now run `CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel` However it doesn't actually do anything with autoparallel yet. Next step is to attach local_map to the model so that autoparallel can run. --- .../experiments/auto_parallel/__init__.py | 18 +++ .../auto_parallel/parallelize_deepseekv3.py | 135 ++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py diff --git a/torchtitan/experiments/auto_parallel/__init__.py b/torchtitan/experiments/auto_parallel/__init__.py index 8f5a876b4e..d707d43088 100644 --- a/torchtitan/experiments/auto_parallel/__init__.py +++ b/torchtitan/experiments/auto_parallel/__init__.py @@ -13,7 +13,11 @@ from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer from torchtitan.protocols.train_spec import register_train_spec, TrainSpec +from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers +from torchtitan.models.deepseek_v3 import deepseekv3_configs, DeepSeekV3Model from .parallelize_llama import parallelize_llama +from .parallelize_deepseekv3 import parallelize_deepseekv3 + register_train_spec( TrainSpec( @@ -29,3 +33,17 @@ build_loss_fn=build_cross_entropy_loss, ) ) +register_train_spec( + TrainSpec( + name="deepseekv3_auto_parallel", + cls=DeepSeekV3Model, + config=deepseekv3_configs, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=None, + build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) +) diff --git a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py new file mode 100644 index 0000000000..946ec8a199 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import torch + +from autoparallel.api import AutoParallel + +from torch.distributed import DeviceMesh +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard + +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims + +from torchtitan.tools.logging import logger + + +def parallelize_deepseekv3( + model, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard + global_batch_size = job_config.training.local_batch_size * dp_degree + return ( + torch.randint( + 0, + # job_config.training.vocab_size, + model.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + # torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( + # lambda bucket_idx: 500 / parallel_dims.tp + # ) + # torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = ( + # lambda bucket_idx: 1000 / parallel_dims.tp + # ) + + # bail out + return model + + # if job_config.experimental.autop_force_bf16: + # logger.info("Forcing bf16 on model") + # model = model.bfloat16() + + # param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] + # reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] + # mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + # with AutoParallel( + # model, + # input_fn, + # world_mesh, + # mp_policy=mp_policy, + # compile=job_config.training.compile, + # ) as autop: + # autop.add_parameter_memory_constraint(low=None, high=None) + + # possible_input_shardings = { + # # maps relative to mesh dim names used in torchtitan + # "dp_replicate": Shard(0), + # "dp_shard": Shard(0), + # "tp": Replicate(), + # } + # # only used if loss parallel is enabled + # possible_output_shardings = { + # # maps relative to mesh dim names used in torchtitan + # "dp_shard": Shard(0), + # "tp": Shard(2), + # } + # assert all( + # name in possible_input_shardings for name in world_mesh.mesh_dim_names + # ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + # x_sharding = tuple( + # possible_input_shardings[name] for name in world_mesh.mesh_dim_names + # ) + # out_sharding = x_sharding + # if parallel_dims.loss_parallel_enabled: + # out_sharding = tuple( + # possible_output_shardings[name] + # for name in world_mesh.mesh_dim_names + # if name != "dp_replicate" + # ) + # autop.add_input_constraints([x_sharding]) + # autop.add_output_constraints([out_sharding]) + # t0 = time.time() + # sharding_placement = autop.optimize_placement() + # t1 = time.time() + # logger.info(f"AutoParallel took {t1 - t0} seconds") + # parallel_mod = autop.apply_placement(sharding_placement) + + if parallel_dims.loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. + # We should just use AutoParallel to do this for us, but + # it would require putting the loss inside the model as well + def _return_as_dtensor_for_loss_parallel(module, args, output): + return torch.distributed.tensor.DTensor.from_local( + output, world_mesh["tp"], (Shard(2),) + ) + + # not keeping a reference to the hook, don't plan on + # removing it at any point + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + + return parallel_mod From 1233902a54e88851f4381349d6df1ecb67134ba7 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 19 Aug 2025 15:49:44 -0700 Subject: [PATCH 18/49] [dsv3] patch graph break fix, works up until sharding rules --- .../auto_parallel/parallelize_deepseekv3.py | 24 +-- torchtitan/models/deepseek_v3/model/moe.py | 140 +++++++++--------- 2 files changed, 81 insertions(+), 83 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py index 946ec8a199..7ef9110acc 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py +++ b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py @@ -44,8 +44,7 @@ def input_fn(): return ( torch.randint( 0, - # job_config.training.vocab_size, - model.vocab_size, + model.model_args.vocab_size, (global_batch_size, job_config.training.seq_len), device=torch.device("cuda"), ), @@ -63,9 +62,6 @@ def input_fn(): # lambda bucket_idx: 1000 / parallel_dims.tp # ) - # bail out - return model - # if job_config.experimental.autop_force_bf16: # logger.info("Forcing bf16 on model") # model = model.bfloat16() @@ -73,13 +69,17 @@ def input_fn(): # param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] # reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] # mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) - # with AutoParallel( - # model, - # input_fn, - # world_mesh, - # mp_policy=mp_policy, - # compile=job_config.training.compile, - # ) as autop: + mp_policy = None + with AutoParallel( + model, + input_fn, + world_mesh, + mp_policy=mp_policy, + compile=job_config.training.compile, + ) as autop: + # currently errors due to missing sharding prop rules + torch.distributed.breakpoint() + # autop.add_parameter_memory_constraint(low=None, high=None) # possible_input_shardings = { diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py index 2554d61310..86408f82c5 100644 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ b/torchtitan/models/deepseek_v3/model/moe.py @@ -48,6 +48,73 @@ def init_weights(self, init_std: float = 0.02): nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) +# TODO: keeping this for-loop implementation for comparison +# and readability, may remove later +@expert_parallel +def _run_experts_for_loop( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, +) -> torch.Tensor: + if num_tokens_per_expert is not None: + # NOTE: this would incur a synchronization between device and host + num_tokens_per_expert = num_tokens_per_expert.tolist() + + # side-effect code due to the usage of generate_permute_indices + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + h = F.silu(torch.matmul(x_expert, w1[expert_idx])) + h = h * torch.matmul(x_expert, w3[expert_idx]) + h = torch.matmul(h, w2[expert_idx]) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # side-effect code due to the usage of generate_permute_indices + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = F.silu(torch.bmm(x, w1)) + h = h * torch.bmm(x, w3) + # out shape (num_experts, tokens_per_expert, dim) + out = torch.bmm(h, w2) + + return out + +@expert_parallel +def _run_experts_grouped_mm( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, +) -> torch.Tensor: + if num_tokens_per_expert is not None: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + # grouped mm between a 2D tensor and a 3D tensor + assert x.dim() == 2 + else: + offsets = None + # fall back to regular bmm between 3D tensors + assert x.dim() == 3 + + h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets)) + h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets) + out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x) + + return out + class GroupedExperts(nn.Module): def __init__( self, @@ -69,83 +136,14 @@ def forward( num_tokens_per_expert: torch.Tensor | None = None, ) -> torch.Tensor: if self.use_grouped_mm: - return GroupedExperts._run_experts_grouped_mm( + return _run_experts_grouped_mm( self.w1, self.w2, self.w3, x, num_tokens_per_expert ) else: - return GroupedExperts._run_experts_for_loop( + return _run_experts_for_loop( self.w1, self.w2, self.w3, x, num_tokens_per_expert ) - # TODO: keeping this for-loop implementation for comparison - # and readability, may remove later - @expert_parallel - @staticmethod - def _run_experts_for_loop( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, - ) -> torch.Tensor: - if num_tokens_per_expert is not None: - # NOTE: this would incur a synchronization between device and host - num_tokens_per_expert = num_tokens_per_expert.tolist() - - # side-effect code due to the usage of generate_permute_indices - num_padding = x.shape[0] - sum(num_tokens_per_expert) - - # a tuple of tensors indexed by experts - # each with shape (tokens_per_expert(varying), dim) - x = torch.split( - x[: sum(num_tokens_per_expert)], - split_size_or_sections=num_tokens_per_expert, - dim=0, - ) - out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): - h = F.silu(torch.matmul(x_expert, w1[expert_idx])) - h = h * torch.matmul(x_expert, w3[expert_idx]) - h = torch.matmul(h, w2[expert_idx]) - # h shape (tokens_per_expert(varying), dim) - out_experts_splits.append(h) - out = torch.cat(out_experts_splits, dim=0) - - # side-effect code due to the usage of generate_permute_indices - out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) - else: - # x shape (num_experts, tokens_per_expert, dim) - h = F.silu(torch.bmm(x, w1)) - h = h * torch.bmm(x, w3) - # out shape (num_experts, tokens_per_expert, dim) - out = torch.bmm(h, w2) - - return out - - @expert_parallel - @staticmethod - def _run_experts_grouped_mm( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, - ) -> torch.Tensor: - if num_tokens_per_expert is not None: - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - # grouped mm between a 2D tensor and a 3D tensor - assert x.dim() == 2 - else: - offsets = None - # fall back to regular bmm between 3D tensors - assert x.dim() == 3 - - h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets)) - h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets) - out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x) - - return out - def init_weights(self, init_std: float): nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02) nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std) From 4f8677b79b14d4157b01dde9c26a2c1ac0f8a695 Mon Sep 17 00:00:00 2001 From: ruisizhang123 Date: Thu, 21 Aug 2025 10:20:47 -0700 Subject: [PATCH 19/49] update simplefsdp pass config --- torchtitan/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 1ebdd984ed..e4da1b1f6b 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -131,7 +131,7 @@ def __init__(self, job_config: JobConfig): print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c") raise - # Configs from Ruisi + # Configs from Ruisi # set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives torch._inductor.config.simplefsdp.relax_ratio = 0 @@ -140,10 +140,10 @@ def __init__(self, job_config: JobConfig): torch._inductor.config.simplefsdp.estimate_verbose = False torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl" # set to True after the first communication estimation results are saved. This would reduce decision making time. - torch._inductor.config.simplefsdp.load_cache = False + torch._inductor.config.simplefsdp.load_cache = False torch._inductor.config.simplefsdp.enable_bucket_ir = True torch._inductor.config.simplefsdp.enable_reorder_ir = True - torch._inductor.config.simplefsdp.simplefsdp_only = True # False for 2d True for 1d + torch._inductor.config.simplefsdp.simplefsdp_only = False # False for 2d True for 1d torch._inductor.config.simplefsdp.peak_memory_offset = 0 torch._inductor.config.simplefsdp.bucketing_type = "auto" From 714cc5b395cac355a146fa66ab0b55e434c50f75 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 21 Aug 2025 23:29:37 -0700 Subject: [PATCH 20/49] [dsv3] disable MoE while we fix local_map, works up until optimizer --- .../auto_parallel/parallelize_deepseekv3.py | 71 +++++++++---------- torchtitan/models/deepseek_v3/model/model.py | 4 +- 2 files changed, 37 insertions(+), 38 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py index 7ef9110acc..efda3327c2 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py +++ b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py @@ -77,43 +77,40 @@ def input_fn(): mp_policy=mp_policy, compile=job_config.training.compile, ) as autop: - # currently errors due to missing sharding prop rules - torch.distributed.breakpoint() - - # autop.add_parameter_memory_constraint(low=None, high=None) - - # possible_input_shardings = { - # # maps relative to mesh dim names used in torchtitan - # "dp_replicate": Shard(0), - # "dp_shard": Shard(0), - # "tp": Replicate(), - # } - # # only used if loss parallel is enabled - # possible_output_shardings = { - # # maps relative to mesh dim names used in torchtitan - # "dp_shard": Shard(0), - # "tp": Shard(2), - # } - # assert all( - # name in possible_input_shardings for name in world_mesh.mesh_dim_names - # ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" - # x_sharding = tuple( - # possible_input_shardings[name] for name in world_mesh.mesh_dim_names - # ) - # out_sharding = x_sharding - # if parallel_dims.loss_parallel_enabled: - # out_sharding = tuple( - # possible_output_shardings[name] - # for name in world_mesh.mesh_dim_names - # if name != "dp_replicate" - # ) - # autop.add_input_constraints([x_sharding]) - # autop.add_output_constraints([out_sharding]) - # t0 = time.time() - # sharding_placement = autop.optimize_placement() - # t1 = time.time() - # logger.info(f"AutoParallel took {t1 - t0} seconds") - # parallel_mod = autop.apply_placement(sharding_placement) + autop.add_parameter_memory_constraint(low=None, high=None) + + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "dp_shard": Shard(0), + "tp": Replicate(), + } + # only used if loss parallel is enabled + possible_output_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_shard": Shard(0), + "tp": Shard(2), + } + assert all( + name in possible_input_shardings for name in world_mesh.mesh_dim_names + ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + x_sharding = tuple( + possible_input_shardings[name] for name in world_mesh.mesh_dim_names + ) + out_sharding = x_sharding + if parallel_dims.loss_parallel_enabled: + out_sharding = tuple( + possible_output_shardings[name] + for name in world_mesh.mesh_dim_names + if name != "dp_replicate" + ) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) if parallel_dims.loss_parallel_enabled: diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 61034e4c7d..946de75b58 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -270,7 +270,9 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.attention = Attention(model_args) self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.moe_enabled = layer_id >= model_args.n_dense_layers + # self.moe_enabled = layer_id >= model_args.n_dense_layers + # TODO: enable me when local_map works + self.moe_enabled = False if self.moe_enabled: self.moe = MoE(model_args) From bfa9f7f9c414bcea16b62ea08d1fcdb7463e8937 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 5 Sep 2025 15:10:35 -0400 Subject: [PATCH 21/49] tweak ds3 model.py to reflect main branch for DS3 baseline can run (#1684) --- torchtitan/models/deepseek_v3/model/model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 8f8a512f63..9074919c99 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -270,9 +270,7 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.attention = Attention(model_args) self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - # self.moe_enabled = layer_id >= model_args.n_dense_layers - # TODO: enable me when local_map works - self.moe_enabled = False + self.moe_enabled = layer_id >= model_args.n_dense_layers if self.moe_enabled: self.moe = MoE( From 75fb2eb4d0eb9438a7e6a7a0952032995f4f6e2a Mon Sep 17 00:00:00 2001 From: Ruisi Zhang Date: Sat, 6 Sep 2025 00:28:39 -0700 Subject: [PATCH 22/49] add simplefsdp's autobucketing pass entry (#1658) as titled, this pr adds entry to simplefsdp's autobucketing pass in autoparallel. original code is in: https://github.com/pytorch/pytorch/pull/160282 The main code for autobucketing pass will be added to autoparallel repo. --- .../experiments/auto_parallel/README.md | 4 ++ torchtitan/train.py | 38 +++++++++---------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/auto_parallel/README.md index ef66a59166..7e112329b9 100644 --- a/torchtitan/experiments/auto_parallel/README.md +++ b/torchtitan/experiments/auto_parallel/README.md @@ -4,4 +4,8 @@ requires installing git@github.com:pytorch-labs/autoparallel.git `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4` +Use simplefsdp's autobucketing pass: + +`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --experimental.enable_simplefsdp_passes --compile.enable` + (or llama3-8b.toml) diff --git a/torchtitan/train.py b/torchtitan/train.py index 2829aa3c55..b69b74faac 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -8,6 +8,7 @@ import os import time from datetime import timedelta +from functools import partial from typing import Any, Generator, Iterable, Optional import torch @@ -130,32 +131,29 @@ def __init__(self, job_config: JobConfig): # allow configuring inductor comms optimizations from torchtitan commandline if job_config.experimental.enable_simplefsdp_passes: - try: - from torch._inductor.simple_fsdp.bucket import bucket_fsdp_all_gather_concat_on_scheduler_ir - except ImportError: - print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c") - raise - - # Configs from Ruisi + # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282) + from autoparallel.auto_bucketing import ( + simple_fsdp_autobucketing_reordering_pass, + simplefsdp_autobucketing_config, + ) - # set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives - torch._inductor.config.simplefsdp.relax_ratio = 0 torch._inductor.config.allow_buffer_reuse = False - torch._inductor.config.simplefsdp.estimate_ir = False - torch._inductor.config.simplefsdp.estimate_verbose = False - torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl" - # set to True after the first communication estimation results are saved. This would reduce decision making time. - torch._inductor.config.simplefsdp.load_cache = False - torch._inductor.config.simplefsdp.enable_bucket_ir = True - torch._inductor.config.simplefsdp.enable_reorder_ir = True - torch._inductor.config.simplefsdp.simplefsdp_only = False # False for 2d True for 1d - torch._inductor.config.simplefsdp.peak_memory_offset = 0 - torch._inductor.config.simplefsdp.bucketing_type = "auto" + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.reorder_for_compute_comm_overlap = True + simplefsdp_autobucketing_config.save_estimation_path = ( + "/tmp/torchtitan_simplefsdp_comm_estimation.pkl" + ) + simple_fsdp_autobucketing_reordering_pass = partial( + simple_fsdp_autobucketing_reordering_pass, + configs=simplefsdp_autobucketing_config, + ) + torch._inductor.config.reorder_for_compute_comm_overlap_passes = [ + simple_fsdp_autobucketing_reordering_pass + ] # Don't use both sets of passes at the same time! torch._inductor.config.bucket_all_gathers_fx = "none" torch._inductor.config.bucket_reduce_scatters_fx = "none" - torch._inductor.config.reorder_for_compute_comm_overlap = False else: torch._inductor.config.bucket_all_gathers_fx = ( job_config.experimental.bucket_all_gathers_fx From 8769396256c5ad8ba9e9ea15787a32b2bdb8bd91 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 11 Sep 2025 12:26:22 -0700 Subject: [PATCH 23/49] [dsv3] 1D AP w/ local_map --- torchtitan/components/optimizer.py | 43 ++++-- .../auto_parallel/parallelize_deepseekv3.py | 95 +++++++++++- torchtitan/models/moe.py | 137 ++++++++++-------- torchtitan/train.py | 2 +- 4 files changed, 202 insertions(+), 75 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index d3e9628103..1cadefbf71 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -351,12 +351,35 @@ def _update_expert_bias( dp_cp_mesh = ( parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None ) + + ################################################################3 + # AP friendly methods + + def is_moe_block(block): + moe_enabled = getattr(block, "moe_enabled", False) + has_moe_submod = hasattr(block, "moe") # AP + return moe_enabled or has_moe_submod + + def get_transformer_blocks(model_part): + if isinstance(model_part.layers, nn.ModuleDict): + # regular torchtitan + blocks = model_part.layers.values() + else: + # TODO: fix autoparallel to preserve the module dict + blocks = model_part.layers.children() + return blocks + + def should_manual_allreduce(tokens_per_expert_by_layer): + return not isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor) + ################################################################3 + # TODO: Currently this sync is blocking (thus exposed) and happens on the # default compute stream. Need to assess if this is OK performance-wise. tokens_per_expert_list = [] for model_part in model_parts: - for transformer_block in model_part.layers.values(): - if not transformer_block.moe_enabled: + blocks = get_transformer_blocks(model_part) + for transformer_block in blocks: + if not is_moe_block(transformer_block): continue if transformer_block.moe.load_balance_coeff is None: return @@ -372,17 +395,19 @@ def _update_expert_bias( tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list) if dp_cp_mesh is not None: - # Perform single all-reduce to get global statistics across all processes - pg = dp_cp_mesh.get_group() - torch.distributed.all_reduce( - tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM - ) + if should_manual_allreduce(tokens_per_expert_by_layer): + # Perform single all-reduce to get global statistics across all processes + pg = dp_cp_mesh.get_group() + torch.distributed.all_reduce( + tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM + ) moe_layer_idx = 0 with torch.no_grad(): for model_part in model_parts: - for transformer_block in model_part.layers.values(): - if not transformer_block.moe_enabled: + blocks = get_transformer_blocks(model_part) + for transformer_block in blocks: + if not is_moe_block(transformer_block): continue moe = transformer_block.moe diff --git a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py index c1991dfa5a..cf69511e0a 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py +++ b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py @@ -19,14 +19,59 @@ from torchtitan.tools.logging import logger +def apply_local_map_to_moe(): + """ + TODO: fix HOPs not restoring the original signature. + TODO: fix tracing with local shapes so that we can use Shard placements + + Current HOP signature we get: + + class subgraph_0(torch.nn.Module): + def forward(self, + rms_norm_5: "f32[64, 2048, 256][524288, 256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__router____modules__gate____parameters__weight: "f32[8, 256][256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____buffers__expert_bias: "f32[8][1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w1: "f32[8, 256, 256][65536, 256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w3: "f32[8, 256, 256][65536, 256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w2: "f32[8, 256, 256][65536, 256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w1____parameters__weight: "f32[512, 256][256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w3____parameters__weight: "f32[512, 256][256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w2____parameters__weight: "f32[256, 512][512, 1]cuda:0"): + """ + from torchtitan.models import moe + from torch.distributed._tensor.experimental import local_map + moe._moe_forward = local_map( + moe._moe_forward, + out_placements=( + (Replicate(),), # (Shard(0),), + (Replicate(),), + ), + in_placements=( + (Replicate(),), # (Shard(0),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + ), + redistribute_inputs=True, + in_grad_placements=None, + device_mesh=None, + ) + + +# Run workflow with: +# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel def parallelize_deepseekv3( model, parallel_dims: ParallelDims, job_config: JobConfig, ): """ - Apply tensor parallelism, activation checkpointing, torch.compile, and data - parallelism to the model. + Apply Autoparallel to the model NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. @@ -54,6 +99,9 @@ def input_fn(): assert parallel_dims.cp_enabled is False, "CP not supported yet" assert parallel_dims.pp_enabled is False, "PP not supported yet" + # apply local_map to MoE + apply_local_map_to_moe() + # torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( # lambda bucket_idx: 500 / parallel_dims.tp # ) @@ -131,4 +179,47 @@ def _return_as_dtensor_for_loss_parallel(module, args, output): # removing it at any point parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + _preserve_moe_attributes(model, parallel_mod) + return parallel_mod + + +def _preserve_moe_attributes(original_model, parallel_model): + """ + Preserve MoE custom attributes from the original model to the parallel model. + This is only needed for attributes that aren't used in the graph, so they aren't + lifted as graph inputs and fetched by the pre-graph runtime wrapper. + + `moe_enabled` ane `load_balance_coeff` are used later in the optimizer to identify + this block as a moe block. This should be safe as they are read-only. + """ + def get_moe_modules(model): + """Extract all MoE modules from the model.""" + moe_modules = [] + if hasattr(model, 'layers'): + if isinstance(model.layers, torch.nn.ModuleDict): + # regular torchtitan structure + blocks = model.layers.values() + else: + # autoparallel might change structure + blocks = model.layers.children() if hasattr(model.layers, 'children') else [] + + for block in blocks: + if hasattr(block, 'moe_enabled') and block.moe_enabled and hasattr(block, 'moe'): + moe_modules.append(block.moe) + elif hasattr(block, 'moe'): # fallback for autoparallel + moe_modules.append(block.moe) + return moe_modules + + original_moe_modules = get_moe_modules(original_model) + parallel_moe_modules = get_moe_modules(parallel_model) + + # Copy custom attributes from original to parallel MoE modules + # This is fine to do since these attributes are read only + for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules): + if hasattr(orig_moe, 'moe_enabled'): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff + + # Copy load_balance_coeff + if hasattr(orig_moe, 'load_balance_coeff'): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py index 8be14ecbf0..1ec8e3b23b 100644 --- a/torchtitan/models/moe.py +++ b/torchtitan/models/moe.py @@ -12,6 +12,7 @@ from torch import nn from torchtitan.distributed.expert_parallel import expert_parallel +from torch.distributed.tensor.placement_types import Shard, Replicate @dataclass @@ -310,6 +311,77 @@ def forward( num_tokens_per_expert, ) +def _moe_forward(x, router, expert_bias, reorderer, score_before_experts, experts, shared_experts): + # x: 64, 2048, 256 + bs, slen, dim = x.shape + x = x.view(-1, dim) + + # top_scores and selected_experts_indices shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, + selected_experts_indices, + num_tokens_per_expert, + ) = router(x, expert_bias) + + # tokens_per_expert will be used to update the expert bias for load balancing. + # and also to count the expert usage + # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- + # first in the forward pass, and then in the backward pass. However, this has no + # effect on the expert bias update thanks to the torch.sign() operator. + # moved out to remove mutation + # with torch.no_grad(): + # tokens_per_expert.add_(num_tokens_per_expert) + + # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + # NOTE: the reason we need to compute num_tokens_per_expert again is: + # 1st computation in router is to update self.tokens_per_expert + # which would be the same across all TP ranks. + # 2nd computation in reorderer is for the actual routing and experts computation + # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. + # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. + ( + top_scores_experts_sorted, + token_indices_experts_sorted, + num_tokens_per_expert, + ) = reorderer(top_scores, selected_experts_indices) + + # shape (bs*slen*top_k, dim) + token_indices_experts_sorted = token_indices_experts_sorted.reshape( + -1, 1 + ).expand(-1, dim) + + # shape (bs*slen*top_k, dim) + routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) + + if score_before_experts: + routed_input = ( + routed_input.to(torch.float32) + * top_scores_experts_sorted.reshape(-1, 1) + ).to(x.dtype) + + # shape (bs*slen*top_k, dim) + routed_output = experts(routed_input, num_tokens_per_expert) + + if not score_before_experts: + routed_output = ( + routed_output.to(torch.float32) + * top_scores_experts_sorted.reshape(-1, 1) + ).to(x.dtype) + + # shared expert + if shared_experts is not None: + out = shared_experts(x) + else: + out = torch.zeros_like(x) + + out = out.scatter_add( + dim=0, index=token_indices_experts_sorted, src=routed_output + ) + out = out.reshape(bs, slen, dim) + return out, num_tokens_per_expert + class MoE(nn.Module): def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): @@ -367,72 +439,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. """ - bs, slen, dim = x.shape - x = x.view(-1, dim) - - # top_scores and selected_experts_indices shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) - ( - top_scores, - selected_experts_indices, - num_tokens_per_expert, - ) = self.router(x, self.expert_bias) + out, num_tokens_per_expert = _moe_forward(x, self.router, self.expert_bias, self.reorderer, self.score_before_experts, self.experts, self.shared_experts) - # tokens_per_expert will be used to update the expert bias for load balancing. - # and also to count the expert usage - # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- - # first in the forward pass, and then in the backward pass. However, this has no - # effect on the expert bias update thanks to the torch.sign() operator. + # HOPs don't support buffer mutations, keep this outside with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) - - # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) - # NOTE: the reason we need to compute num_tokens_per_expert again is: - # 1st computation in router is to update self.tokens_per_expert - # which would be the same across all TP ranks. - # 2nd computation in reorderer is for the actual routing and experts computation - # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. - # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. - ( - top_scores_experts_sorted, - token_indices_experts_sorted, - num_tokens_per_expert, - ) = self.reorderer(top_scores, selected_experts_indices) - - # shape (bs*slen*top_k, dim) - token_indices_experts_sorted = token_indices_experts_sorted.reshape( - -1, 1 - ).expand(-1, dim) - - # shape (bs*slen*top_k, dim) - routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) - - if self.score_before_experts: - routed_input = ( - routed_input.to(torch.float32) - * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) - - # shape (bs*slen*top_k, dim) - routed_output = self.experts(routed_input, num_tokens_per_expert) - - if not self.score_before_experts: - routed_output = ( - routed_output.to(torch.float32) - * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) - - # shared expert - if self.shared_experts is not None: - out = self.shared_experts(x) - else: - out = torch.zeros_like(x) - - out = out.scatter_add( - dim=0, index=token_indices_experts_sorted, src=routed_output - ) - out = out.reshape(bs, slen, dim) return out def init_weights( diff --git a/torchtitan/train.py b/torchtitan/train.py index b69b74faac..3396fa56d2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -307,7 +307,7 @@ def __init__(self, job_config: JobConfig): # confirm that user will be able to view loss metrics on the console ensure_pp_loss_visible(parallel_dims, job_config, color) else: - # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel + # apply Autoparallel model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) model.to_empty(device=init_device) From db224791abca94707ead46efe4de4d776ec37e7d Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Wed, 17 Sep 2025 16:00:36 -0700 Subject: [PATCH 24/49] [dsv3] Turn off Flex for AP --- torchtitan/models/deepseek_v3/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 1c3d2b19d2..bf3232fd04 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -100,8 +100,8 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=True, - attn_mask_type="block_causal", + # use_flex_attn=True, + # attn_mask_type="block_causal", ), "236B": DeepSeekV3ModelArgs( vocab_size=102400, @@ -127,8 +127,8 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, - attn_mask_type="block_causal", + # use_flex_attn=True, + # attn_mask_type="block_causal", ), "671B": DeepSeekV3ModelArgs( vocab_size=129280, @@ -154,8 +154,8 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, - attn_mask_type="block_causal", + # use_flex_attn=True, + # attn_mask_type="block_causal", ), } From 9dc0bd8c71ab32b36cb262f5ccffd8e1d97a7a1a Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Mon, 27 Oct 2025 16:29:06 -0700 Subject: [PATCH 25/49] Update to new model registration API --- torchtitan/experiments/__init__.py | 3 +- .../experiments/auto_parallel/__init__.py | 54 ------------------- .../auto_parallel/deepseek_v3/__init__.py | 36 +++++++++++++ .../parallelize_deepseekv3.py | 2 +- .../auto_parallel/llama3/__init__.py | 39 ++++++++++++++ .../{ => llama3}/parallelize_llama.py | 0 6 files changed, 78 insertions(+), 56 deletions(-) delete mode 100644 torchtitan/experiments/auto_parallel/__init__.py create mode 100644 torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py rename torchtitan/experiments/auto_parallel/{ => deepseek_v3}/parallelize_deepseekv3.py (99%) create mode 100644 torchtitan/experiments/auto_parallel/llama3/__init__.py rename torchtitan/experiments/auto_parallel/{ => llama3}/parallelize_llama.py (100%) diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 9bb3f101e1..044f9cc80d 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -12,6 +12,7 @@ "simple_fsdp.deepseek_v3", "vlm", "compiler_toolkit.deepseek_v3", - "autoparallel", + "auto_parallel.llama3", + "auto_parallel.deepseek_v3", ] ) diff --git a/torchtitan/experiments/auto_parallel/__init__.py b/torchtitan/experiments/auto_parallel/__init__.py deleted file mode 100644 index a67dfe18aa..0000000000 --- a/torchtitan/experiments/auto_parallel/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -# -# Copyright (c) Meta Platforms, Inc. All Rights Reserved. - -from torchtitan.components.loss import build_cross_entropy_loss -from torchtitan.components.lr_scheduler import build_lr_schedulers -from torchtitan.components.optimizer import build_optimizers, build_optimizers_with_moe_load_balancing -from torchtitan.components.validate import build_validator -from torchtitan.components.tokenizer import build_hf_tokenizer -from torchtitan.datasets.hf_datasets import build_hf_dataloader -from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer -from torchtitan.models.llama3.model.state_dict_adapter import Llama3StateDictAdapter -from torchtitan.models.deepseek_v3.model.state_dict_adapter import DeepSeekV3StateDictAdapter -from torchtitan.protocols.train_spec import register_train_spec, TrainSpec -from torchtitan.models.deepseek_v3 import deepseekv3_configs, DeepSeekV3Model -from .parallelize_llama import parallelize_llama -from .parallelize_deepseekv3 import parallelize_deepseekv3 - - -register_train_spec( - TrainSpec( - name="llama3_auto_parallel", - model_cls=Transformer, - model_args=llama3_configs, - parallelize_fn=parallelize_llama, - pipelining_fn=pipeline_llama, - build_optimizers_fn=build_optimizers, - build_lr_schedulers_fn=build_lr_schedulers, - build_dataloader_fn=build_hf_dataloader, - build_tokenizer_fn=build_hf_tokenizer, - build_loss_fn=build_cross_entropy_loss, - build_validator_fn=build_validator, - state_dict_adapter=Llama3StateDictAdapter, - ) -) -register_train_spec( - TrainSpec( - name="deepseekv3_auto_parallel", - model_cls=DeepSeekV3Model, - model_args=deepseekv3_configs, - parallelize_fn=parallelize_deepseekv3, - pipelining_fn=pipeline_llama, - build_optimizers_fn=build_optimizers_with_moe_load_balancing, - build_lr_schedulers_fn=build_lr_schedulers, - build_dataloader_fn=build_hf_dataloader, - build_tokenizer_fn=build_hf_tokenizer, - build_loss_fn=build_cross_entropy_loss, - state_dict_adapter=DeepSeekV3StateDictAdapter, - ) -) diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py b/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py new file mode 100644 index 0000000000..eb803b9300 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.hf_datasets.text_datasets import build_text_dataloader +from torchtitan.models.moe import MoEArgs +from torchtitan.protocols.train_spec import TrainSpec + +from torchtitan.models.deepseek_v3 import deepseekv3_args, DeepSeekV3ModelArgs, DeepSeekV3Model +from torchtitan.models.deepseek_v3.model.state_dict_adapter import DeepSeekV3StateDictAdapter + +from .parallelize_deepseekv3 import parallelize_deepseekv3 + + +def get_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=DeepSeekV3Model, + model_args=deepseekv3_args, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + state_dict_adapter=DeepSeekV3StateDictAdapter, + ) diff --git a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py similarity index 99% rename from torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py rename to torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py index cf69511e0a..9bec6a0e17 100644 --- a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py @@ -38,7 +38,7 @@ def forward(self, self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w3____parameters__weight: "f32[512, 256][256, 1]cuda:0", self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w2____parameters__weight: "f32[256, 512][512, 1]cuda:0"): """ - from torchtitan.models import moe + from torchtitan.models.moe import moe from torch.distributed._tensor.experimental import local_map moe._moe_forward = local_map( moe._moe_forward, diff --git a/torchtitan/experiments/auto_parallel/llama3/__init__.py b/torchtitan/experiments/auto_parallel/llama3/__init__.py new file mode 100644 index 0000000000..c387fede59 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/llama3/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator +from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.hf_datasets.text_datasets import build_text_dataloader +from torchtitan.protocols.train_spec import TrainSpec + +from torchtitan.models.llama3 import llama3_args, TransformerModelArgs, Transformer +from torchtitan.models.llama3.model.state_dict_adapter import Llama3StateDictAdapter + +from .parallelize_llama import parallelize_llama + + +# CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.llama3 --parallelism.tensor_parallel_degree 4 + +def get_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=Transformer, + model_args=llama3_args, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + build_validator_fn=build_validator, + state_dict_adapter=Llama3StateDictAdapter, + ) diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py similarity index 100% rename from torchtitan/experiments/auto_parallel/parallelize_llama.py rename to torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py From c6e25bd3e28e4ee81e029e62b7178aa68ec1e511 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 5 Nov 2025 16:44:09 -0800 Subject: [PATCH 26/49] Whc/knobs (#1994) needs to merge in lock step with https://github.com/meta-pytorch/autoparallel/pull/233 --- torchtitan/config/job_config.py | 30 ++--------------------- torchtitan/train.py | 42 ++------------------------------- 2 files changed, 4 insertions(+), 68 deletions(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 6861a7f1a8..2ec8b84e03 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -857,37 +857,11 @@ class Experimental: needs to ensure that the path can be imported. """ - # "none", "all", "only_fsdp" - bucket_all_gathers_fx: str = "none" - - # "none", "all" - bucket_reduce_scatters_fx: str = "none" - - reorder_for_compute_comm_overlap: bool = False - """ - Whether to enable inductor comm reordering passes - """ - - reorder_for_compute_comm_overlap_passes: list[str] = field( - default_factory=lambda: [ - "sink_waits_iterative", - "reorder_communication_preserving_peak_memory", - ] - ) - """ - Sequence of reordering passes (names of functions inside _inductor.comms) to call, - if reorder_for_compute_comm_overlap is enabled. - """ - - reorder_prefetch_limit: int | None = None - """ - How many ops to allow moving any individual collective, if 'reorder_communication_preserving_peak_memory' - pass is enabled. default of None means unlimited - """ + # "aten" (default), "inductor", "none" + comms_bucket_reorder_strategy: str = "aten" autop_force_bf16: bool = False - enable_simplefsdp_passes: bool = False @dataclass class Validation: diff --git a/torchtitan/train.py b/torchtitan/train.py index 7d492f7249..61ffe4ec1a 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -33,6 +33,7 @@ maybe_enable_memory_snapshot, maybe_enable_profiling, ) +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing class Trainer(torch.distributed.checkpoint.stateful.Stateful): @@ -122,46 +123,7 @@ def __init__(self, job_config: JobConfig): torch._inductor.config.allow_buffer_reuse = False # allow configuring inductor comms optimizations from torchtitan commandline - if job_config.experimental.enable_simplefsdp_passes: - # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282) - from autoparallel.auto_bucketing import ( - simple_fsdp_autobucketing_reordering_pass, - simplefsdp_autobucketing_config, - ) - - torch._inductor.config.allow_buffer_reuse = False - torch._inductor.config.reorder_for_peak_memory = False - torch._inductor.config.reorder_for_compute_comm_overlap = True - simplefsdp_autobucketing_config.save_estimation_path = ( - "/tmp/torchtitan_simplefsdp_comm_estimation.pkl" - ) - simple_fsdp_autobucketing_reordering_pass = partial( - simple_fsdp_autobucketing_reordering_pass, - configs=simplefsdp_autobucketing_config, - ) - torch._inductor.config.reorder_for_compute_comm_overlap_passes = [ - simple_fsdp_autobucketing_reordering_pass - ] - - # Don't use both sets of passes at the same time! - torch._inductor.config.bucket_all_gathers_fx = "none" - torch._inductor.config.bucket_reduce_scatters_fx = "none" - else: - torch._inductor.config.bucket_all_gathers_fx = ( - job_config.experimental.bucket_all_gathers_fx - ) - torch._inductor.config.bucket_reduce_scatters_fx = ( - job_config.experimental.bucket_reduce_scatters_fx - ) - torch._inductor.config.reorder_for_compute_comm_overlap = ( - job_config.experimental.reorder_for_compute_comm_overlap - ) - torch._inductor.config.reorder_for_compute_comm_overlap_passes = ( - job_config.experimental.reorder_for_compute_comm_overlap_passes - ) - torch._inductor.config.reorder_prefetch_limit = ( - job_config.experimental.reorder_prefetch_limit - ) + configure_inductor_for_autobucketing(job_config.experimental.comms_bucket_reorder_strategy) # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). From e6ea814760e51e7751982a1aaef5632aba9f3d99 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 18 Nov 2025 09:10:08 -0800 Subject: [PATCH 27/49] lint --- torchtitan/components/optimizer.py | 13 ++--- .../auto_parallel/deepseek_v3/__init__.py | 9 ++-- .../deepseek_v3/parallelize_deepseekv3.py | 47 +++++++++---------- .../auto_parallel/llama3/__init__.py | 8 ++-- .../auto_parallel/llama3/parallelize_llama.py | 5 +- torchtitan/models/moe/moe.py | 31 +++++++----- torchtitan/tools/profiling.py | 6 ++- torchtitan/train.py | 12 ++--- 8 files changed, 71 insertions(+), 60 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 2328588279..49bc9bd534 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -360,12 +360,10 @@ def _update_expert_bias( parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None ) - ################################################################3 # AP friendly methods - def is_moe_block(block): moe_enabled = getattr(block, "moe_enabled", False) - has_moe_submod = hasattr(block, "moe") # AP + has_moe_submod = hasattr(block, "moe") # AP return moe_enabled or has_moe_submod def get_transformer_blocks(model_part): @@ -378,8 +376,9 @@ def get_transformer_blocks(model_part): return blocks def should_manual_allreduce(tokens_per_expert_by_layer): - return not isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor) - ################################################################3 + return not isinstance( + tokens_per_expert_by_layer, torch.distributed.tensor.DTensor + ) # TODO: Currently this sync is blocking (thus exposed) and happens on the # default compute stream. Need to assess if this is OK performance-wise. @@ -407,7 +406,9 @@ def should_manual_allreduce(tokens_per_expert_by_layer): # Perform single all-reduce to get global statistics across all processes pg = dp_cp_mesh.get_group() torch.distributed.all_reduce( - tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM + tokens_per_expert_by_layer, + group=pg, + op=torch.distributed.ReduceOp.SUM, ) moe_layer_idx = 0 diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py b/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py index eb803b9300..ff220c496b 100644 --- a/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py @@ -12,11 +12,12 @@ from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.distributed.pipeline_parallel import pipeline_llm from torchtitan.hf_datasets.text_datasets import build_text_dataloader -from torchtitan.models.moe import MoEArgs -from torchtitan.protocols.train_spec import TrainSpec -from torchtitan.models.deepseek_v3 import deepseekv3_args, DeepSeekV3ModelArgs, DeepSeekV3Model -from torchtitan.models.deepseek_v3.model.state_dict_adapter import DeepSeekV3StateDictAdapter +from torchtitan.models.deepseek_v3 import deepseekv3_args, DeepSeekV3Model +from torchtitan.models.deepseek_v3.model.state_dict_adapter import ( + DeepSeekV3StateDictAdapter, +) +from torchtitan.protocols.train_spec import TrainSpec from .parallelize_deepseekv3 import parallelize_deepseekv3 diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py index 9bec6a0e17..39abb5f08a 100644 --- a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py @@ -10,10 +10,8 @@ from autoparallel.api import AutoParallel -from torch.distributed import DeviceMesh -from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor.placement_types import Replicate, Shard -from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims from torchtitan.tools.logging import logger @@ -25,29 +23,18 @@ def apply_local_map_to_moe(): TODO: fix tracing with local shapes so that we can use Shard placements Current HOP signature we get: - - class subgraph_0(torch.nn.Module): - def forward(self, - rms_norm_5: "f32[64, 2048, 256][524288, 256, 1]cuda:0", - self____modules__layers____modules__1____modules__moe____modules__router____modules__gate____parameters__weight: "f32[8, 256][256, 1]cuda:0", - self____modules__layers____modules__1____modules__moe____buffers__expert_bias: "f32[8][1]cuda:0", - self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w1: "f32[8, 256, 256][65536, 256, 1]cuda:0", - self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w3: "f32[8, 256, 256][65536, 256, 1]cuda:0", - self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w2: "f32[8, 256, 256][65536, 256, 1]cuda:0", - self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w1____parameters__weight: "f32[512, 256][256, 1]cuda:0", - self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w3____parameters__weight: "f32[512, 256][256, 1]cuda:0", - self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w2____parameters__weight: "f32[256, 512][512, 1]cuda:0"): """ - from torchtitan.models.moe import moe from torch.distributed._tensor.experimental import local_map + from torchtitan.models.moe import moe + moe._moe_forward = local_map( moe._moe_forward, out_placements=( - (Replicate(),), # (Shard(0),), + (Replicate(),), # (Shard(0),), (Replicate(),), ), in_placements=( - (Replicate(),), # (Shard(0),), + (Replicate(),), # (Shard(0),), (Replicate(),), (Replicate(),), (Replicate(),), @@ -146,7 +133,8 @@ def input_fn(): ) out_sharding = x_sharding loss_parallel_enabled = ( - parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel ) if loss_parallel_enabled: out_sharding = tuple( @@ -190,24 +178,31 @@ def _preserve_moe_attributes(original_model, parallel_model): This is only needed for attributes that aren't used in the graph, so they aren't lifted as graph inputs and fetched by the pre-graph runtime wrapper. - `moe_enabled` ane `load_balance_coeff` are used later in the optimizer to identify + `moe_enabled` and `load_balance_coeff` are used later in the optimizer to identify this block as a moe block. This should be safe as they are read-only. """ + def get_moe_modules(model): """Extract all MoE modules from the model.""" moe_modules = [] - if hasattr(model, 'layers'): + if hasattr(model, "layers"): if isinstance(model.layers, torch.nn.ModuleDict): # regular torchtitan structure blocks = model.layers.values() else: # autoparallel might change structure - blocks = model.layers.children() if hasattr(model.layers, 'children') else [] + blocks = ( + model.layers.children() if hasattr(model.layers, "children") else [] + ) for block in blocks: - if hasattr(block, 'moe_enabled') and block.moe_enabled and hasattr(block, 'moe'): + if ( + hasattr(block, "moe_enabled") + and block.moe_enabled + and hasattr(block, "moe") + ): moe_modules.append(block.moe) - elif hasattr(block, 'moe'): # fallback for autoparallel + elif hasattr(block, "moe"): # fallback for autoparallel moe_modules.append(block.moe) return moe_modules @@ -217,9 +212,9 @@ def get_moe_modules(model): # Copy custom attributes from original to parallel MoE modules # This is fine to do since these attributes are read only for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules): - if hasattr(orig_moe, 'moe_enabled'): + if hasattr(orig_moe, "moe_enabled"): par_moe.load_balance_coeff = orig_moe.load_balance_coeff # Copy load_balance_coeff - if hasattr(orig_moe, 'load_balance_coeff'): + if hasattr(orig_moe, "load_balance_coeff"): par_moe.load_balance_coeff = orig_moe.load_balance_coeff diff --git a/torchtitan/experiments/auto_parallel/llama3/__init__.py b/torchtitan/experiments/auto_parallel/llama3/__init__.py index c387fede59..f9e61ddd7e 100644 --- a/torchtitan/experiments/auto_parallel/llama3/__init__.py +++ b/torchtitan/experiments/auto_parallel/llama3/__init__.py @@ -13,15 +13,17 @@ from torchtitan.components.validate import build_validator from torchtitan.distributed.pipeline_parallel import pipeline_llm from torchtitan.hf_datasets.text_datasets import build_text_dataloader -from torchtitan.protocols.train_spec import TrainSpec -from torchtitan.models.llama3 import llama3_args, TransformerModelArgs, Transformer +from torchtitan.models.llama3 import llama3_args, Transformer from torchtitan.models.llama3.model.state_dict_adapter import Llama3StateDictAdapter +from torchtitan.protocols.train_spec import TrainSpec from .parallelize_llama import parallelize_llama -# CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.llama3 --parallelism.tensor_parallel_degree 4 +# CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml +# ./run_train.sh --model.name auto_parallel.llama3 --parallelism.tensor_parallel_degree 4 + def get_train_spec() -> TrainSpec: return TrainSpec( diff --git a/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py b/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py index 6648f29ab8..d1a0009fc1 100644 --- a/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py @@ -10,7 +10,6 @@ from autoparallel.api import AutoParallel -from torch.distributed import DeviceMesh from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor.placement_types import Replicate, Shard @@ -33,6 +32,7 @@ def parallelize_llama( the model must fit on GPU or CPU memory. """ world_mesh = parallel_dims.world_mesh + def input_fn(): global_batch_size = job_config.training.global_batch_size if global_batch_size < 0: @@ -101,7 +101,8 @@ def input_fn(): ) out_sharding = x_sharding loss_parallel_enabled = ( - parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel ) if loss_parallel_enabled: out_sharding = tuple( diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index b6cbd76eed..2f40615e30 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -353,7 +353,10 @@ def forward( num_tokens_per_expert, ) -def _moe_forward(x, router, expert_bias, reorderer, score_before_experts, experts, shared_experts): + +def _moe_forward( + x, router, expert_bias, reorderer, score_before_experts, experts, shared_experts +): # x: 64, 2048, 256 bs, slen, dim = x.shape x = x.view(-1, dim) @@ -390,17 +393,16 @@ def _moe_forward(x, router, expert_bias, reorderer, score_before_experts, expert ) = reorderer(top_scores, selected_experts_indices) # shape (bs*slen*top_k, dim) - token_indices_experts_sorted = token_indices_experts_sorted.reshape( - -1, 1 - ).expand(-1, dim) + token_indices_experts_sorted = token_indices_experts_sorted.reshape(-1, 1).expand( + -1, dim + ) # shape (bs*slen*top_k, dim) routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) if score_before_experts: routed_input = ( - routed_input.to(torch.float32) - * top_scores_experts_sorted.reshape(-1, 1) + routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) ).to(x.dtype) # shape (bs*slen*top_k, dim) @@ -408,8 +410,7 @@ def _moe_forward(x, router, expert_bias, reorderer, score_before_experts, expert if not score_before_experts: routed_output = ( - routed_output.to(torch.float32) - * top_scores_experts_sorted.reshape(-1, 1) + routed_output.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) ).to(x.dtype) # shared expert @@ -418,9 +419,7 @@ def _moe_forward(x, router, expert_bias, reorderer, score_before_experts, expert else: out = torch.zeros_like(x) - out = out.scatter_add( - dim=0, index=token_indices_experts_sorted, src=routed_output - ) + out = out.scatter_add(dim=0, index=token_indices_experts_sorted, src=routed_output) out = out.reshape(bs, slen, dim) return out, num_tokens_per_expert @@ -482,7 +481,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. """ - out, num_tokens_per_expert = _moe_forward(x, self.router, self.expert_bias, self.reorderer, self.score_before_experts, self.experts, self.shared_experts) + out, num_tokens_per_expert = _moe_forward( + x, + self.router, + self.expert_bias, + self.reorderer, + self.score_before_experts, + self.experts, + self.shared_experts, + ) # HOPs don't support buffer mutations, keep this outside with torch.no_grad(): diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 20a7d83273..7edfe66979 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -20,6 +20,7 @@ "https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html" ) + @contextlib.contextmanager def maybe_enable_profiling( profiling_config: ProfilingConfig, @@ -57,7 +58,10 @@ def trace_handler(prof): # but conveniently prints the internal url for perfetto on manifold for mast jobs manifold_mount_prefix = "/mnt/mffuse/" if output_file.find(manifold_mount_prefix) == 0: - manifold_path = os.path.join("torchtrain_datasets/tree", output_file.split(manifold_mount_prefix)[1]) + manifold_path = os.path.join( + "torchtrain_datasets/tree", + output_file.split(manifold_mount_prefix)[1], + ) perfetto_url = ( PERFETTO_UI_ROOT_URL + "#!/?url=https://interncache-all.fbcdn.net/manifold/" diff --git a/torchtitan/train.py b/torchtitan/train.py index fce33fee17..f0830db9fd 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -8,13 +8,12 @@ import os import time from datetime import timedelta -from functools import partial -from typing import Any, Generator, Iterable, Optional +from typing import Any, Generator, Iterable import torch +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing from torch.distributed.elastic.multiprocessing.errors import record -from torch.distributed.tensor import DTensor import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager @@ -34,7 +33,6 @@ maybe_enable_memory_snapshot, maybe_enable_profiling, ) -from autoparallel.auto_bucketing import configure_inductor_for_autobucketing class Trainer(torch.distributed.checkpoint.stateful.Stateful): @@ -108,13 +106,15 @@ def __init__(self, job_config: JobConfig): ) # TODO(whc) - # I do this becuase otherwise sometimes inductor will skip re-running passes like comms reordering + # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering torch._inductor.config.force_disable_caches = True # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. torch._inductor.config.allow_buffer_reuse = False # allow configuring inductor comms optimizations from torchtitan commandline - configure_inductor_for_autobucketing(job_config.experimental.comms_bucket_reorder_strategy) + configure_inductor_for_autobucketing( + job_config.experimental.comms_bucket_reorder_strategy + ) # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). From 7abede87af9faffa3e23f7812df53db67845297c Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 18 Nov 2025 11:50:49 -0800 Subject: [PATCH 28/49] undo moe patching --- .../deepseek_v3/parallelize_deepseekv3.py | 130 +++++++++++++++++- torchtitan/models/moe/moe.py | 95 ++----------- 2 files changed, 139 insertions(+), 86 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py index 39abb5f08a..2adfee062a 100644 --- a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import time +import types import torch @@ -17,7 +18,94 @@ from torchtitan.tools.logging import logger -def apply_local_map_to_moe(): +def _moe_forward( + x, router, expert_bias, reorderer, score_before_experts, experts, shared_experts +): + bs, slen, dim = x.shape + x = x.view(-1, dim) + + # top_scores and selected_experts_indices shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, + selected_experts_indices, + num_tokens_per_expert, + ) = router(x, expert_bias) + num_tokens_per_expert_update = num_tokens_per_expert + + # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + # NOTE: the reason we need to compute num_tokens_per_expert again is: + # 1st computation in router is to update self.tokens_per_expert + # which would be the same across all TP ranks. + # 2nd computation in reorderer is for the actual routing and experts computation + # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. + # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. + ( + top_scores_experts_sorted, + token_indices_experts_sorted, + num_tokens_per_expert, + ) = reorderer(top_scores, selected_experts_indices) + + # shape (bs*slen*top_k, dim) + token_indices_experts_sorted = token_indices_experts_sorted.reshape(-1, 1).expand( + -1, dim + ) + + # shape (bs*slen*top_k, dim) + routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) + + if score_before_experts: + routed_input = ( + routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) + ).to(x.dtype) + + # shape (bs*slen*top_k, dim) + routed_output = experts(routed_input, num_tokens_per_expert) + + # shared expert + # Note: we execute the shared expert before scoring the output of the routed expert + # to "implicitly" overlap the shared expert compute with token combine communication + if shared_experts is not None: + out = shared_experts(x) + else: + out = torch.zeros_like(x) + + if not score_before_experts: + routed_output = ( + routed_output.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) + ).to(x.dtype) + + out = out.scatter_add(dim=0, index=token_indices_experts_sorted, src=routed_output) + out = out.reshape(bs, slen, dim) + return out, num_tokens_per_expert_update + + +def moe_forward(self, x: torch.Tensor) -> torch.Tensor: + out, num_tokens_per_expert = _moe_forward( + x, + self.router, + self.expert_bias, + self.reorderer, + self.score_before_experts, + self.experts, + self.shared_experts, + ) + # HOPs don't support buffer mutations, keep this outside + # tokens_per_expert will be used to update the expert bias for load balancing. + # and also to count the expert usage + # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- + # first in the forward pass, and then in the backward pass. However, this has no + # effect on the expert bias update thanks to the torch.sign() operator. + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) + + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) + return out + + +def monkey_patch_local_map_moe(model, world_mesh): """ TODO: fix HOPs not restoring the original signature. TODO: fix tracing with local shapes so that we can use Shard placements @@ -25,10 +113,11 @@ def apply_local_map_to_moe(): Current HOP signature we get: """ from torch.distributed._tensor.experimental import local_map - from torchtitan.models.moe import moe - moe._moe_forward = local_map( - moe._moe_forward, + # from torchtitan.models.moe import moe + global _moe_forward + _moe_forward = local_map( + _moe_forward, out_placements=( (Replicate(),), # (Shard(0),), (Replicate(),), @@ -46,9 +135,38 @@ def apply_local_map_to_moe(): ), redistribute_inputs=True, in_grad_placements=None, - device_mesh=None, + device_mesh=world_mesh, ) + for block in model.layers.children(): + if not block.moe_enabled: + continue + block.moe.forward = types.MethodType(moe_forward, block.moe) + + # torch.distributed.breakpoint() + # moe.forward = moe_forward + # moe._moe_forward = local_map( + # moe._moe_forward, + # out_placements=( + # (Replicate(),), # (Shard(0),), + # (Replicate(),), + # ), + # in_placements=( + # (Replicate(),), # (Shard(0),), + # (Replicate(),), + # (Replicate(),), + # (Replicate(),), + # (Replicate(),), + # (Replicate(),), + # (Replicate(),), + # (Replicate(),), + # (Replicate(),), + # ), + # redistribute_inputs=True, + # in_grad_placements=None, + # device_mesh=mesh, + # ) + # Run workflow with: # CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel @@ -87,7 +205,7 @@ def input_fn(): assert parallel_dims.pp_enabled is False, "PP not supported yet" # apply local_map to MoE - apply_local_map_to_moe() + monkey_patch_local_map_moe(model, world_mesh) # torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( # lambda bucket_idx: 500 / parallel_dims.tp diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 2f40615e30..295e2193a5 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -354,76 +354,6 @@ def forward( ) -def _moe_forward( - x, router, expert_bias, reorderer, score_before_experts, experts, shared_experts -): - # x: 64, 2048, 256 - bs, slen, dim = x.shape - x = x.view(-1, dim) - - # top_scores and selected_experts_indices shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) - ( - top_scores, - selected_experts_indices, - num_tokens_per_expert, - ) = router(x, expert_bias) - - # tokens_per_expert will be used to update the expert bias for load balancing. - # and also to count the expert usage - # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- - # first in the forward pass, and then in the backward pass. However, this has no - # effect on the expert bias update thanks to the torch.sign() operator. - # moved out to remove mutation - # with torch.no_grad(): - # tokens_per_expert.add_(num_tokens_per_expert) - - # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) - # NOTE: the reason we need to compute num_tokens_per_expert again is: - # 1st computation in router is to update self.tokens_per_expert - # which would be the same across all TP ranks. - # 2nd computation in reorderer is for the actual routing and experts computation - # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. - # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. - ( - top_scores_experts_sorted, - token_indices_experts_sorted, - num_tokens_per_expert, - ) = reorderer(top_scores, selected_experts_indices) - - # shape (bs*slen*top_k, dim) - token_indices_experts_sorted = token_indices_experts_sorted.reshape(-1, 1).expand( - -1, dim - ) - - # shape (bs*slen*top_k, dim) - routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) - - if score_before_experts: - routed_input = ( - routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) - - # shape (bs*slen*top_k, dim) - routed_output = experts(routed_input, num_tokens_per_expert) - - if not score_before_experts: - routed_output = ( - routed_output.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) - - # shared expert - if shared_experts is not None: - out = shared_experts(x) - else: - out = torch.zeros_like(x) - - out = out.scatter_add(dim=0, index=token_indices_experts_sorted, src=routed_output) - out = out.reshape(bs, slen, dim) - return out, num_tokens_per_expert - - class MoE(nn.Module): def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): super().__init__() @@ -481,17 +411,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. """ - out, num_tokens_per_expert = _moe_forward( - x, - self.router, - self.expert_bias, - self.reorderer, - self.score_before_experts, - self.experts, - self.shared_experts, - ) + bs, slen, dim = x.shape + x = x.view(-1, dim) + + # top_scores and selected_experts_indices shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, + selected_experts_indices, + num_tokens_per_expert, + ) = self.router(x, self.expert_bias) - # HOPs don't support buffer mutations, keep this outside + # tokens_per_expert will be used to update the expert bias for load balancing. + # and also to count the expert usage + # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- + # first in the forward pass, and then in the backward pass. However, this has no + # effect on the expert bias update thanks to the torch.sign() operator. with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) From d2e76b7d5bd2377f0763fea778a93ed5650d33d1 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 18 Nov 2025 12:56:54 -0800 Subject: [PATCH 29/49] move inductor config into experiment folders --- .../deepseek_v3/parallelize_deepseekv3.py | 13 +++++++++++++ .../auto_parallel/llama3/parallelize_llama.py | 13 +++++++++++++ torchtitan/train.py | 12 ------------ 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py index 2adfee062a..c826a8819a 100644 --- a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py @@ -10,6 +10,7 @@ import torch from autoparallel.api import AutoParallel +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing from torch.distributed.tensor.placement_types import Replicate, Shard from torchtitan.config import JobConfig @@ -181,6 +182,18 @@ def parallelize_deepseekv3( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ + + # TODO(whc) + # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. + torch._inductor.config.allow_buffer_reuse = False + + # allow configuring inductor comms optimizations from torchtitan commandline + configure_inductor_for_autobucketing( + job_config.experimental.comms_bucket_reorder_strategy + ) + world_mesh = parallel_dims.world_mesh def input_fn(): diff --git a/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py b/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py index d1a0009fc1..1d2bee4351 100644 --- a/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py @@ -9,6 +9,7 @@ import torch from autoparallel.api import AutoParallel +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor.placement_types import Replicate, Shard @@ -31,6 +32,18 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ + + # TODO(whc) + # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. + torch._inductor.config.allow_buffer_reuse = False + + # allow configuring inductor comms optimizations from torchtitan commandline + configure_inductor_for_autobucketing( + job_config.experimental.comms_bucket_reorder_strategy + ) + world_mesh = parallel_dims.world_mesh def input_fn(): diff --git a/torchtitan/train.py b/torchtitan/train.py index f0830db9fd..b01180787a 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,7 +11,6 @@ from typing import Any, Generator, Iterable import torch -from autoparallel.auto_bucketing import configure_inductor_for_autobucketing from torch.distributed.elastic.multiprocessing.errors import record @@ -105,17 +104,6 @@ def __init__(self, job_config: JobConfig): gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug ) - # TODO(whc) - # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering - torch._inductor.config.force_disable_caches = True - # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. - torch._inductor.config.allow_buffer_reuse = False - - # allow configuring inductor comms optimizations from torchtitan commandline - configure_inductor_for_autobucketing( - job_config.experimental.comms_bucket_reorder_strategy - ) - # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( From 472b4ad468315a2e747483457bd2a62e71f5bf0a Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 18 Nov 2025 16:03:44 -0800 Subject: [PATCH 30/49] fix local_map moe patch --- .../deepseek_v3/parallelize_deepseekv3.py | 187 ++++++++++++------ 1 file changed, 129 insertions(+), 58 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py index c826a8819a..66a5b71354 100644 --- a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py @@ -6,8 +6,11 @@ import time import types +from typing import Callable, Optional import torch +import torch.nn as nn +import torch.nn.functional as F from autoparallel.api import AutoParallel from autoparallel.auto_bucketing import configure_inductor_for_autobucketing @@ -15,12 +18,78 @@ from torch.distributed.tensor.placement_types import Replicate, Shard from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims +from torchtitan.models.moe.moe import _run_experts_grouped_mm from torchtitan.tools.logging import logger +def create_functional_router_forward( + self: nn.Module, +) -> Callable: # TokenChoiceTopKRouter + def functional_router_forward( + x: torch.Tensor, gate_weight: torch.nn.Parameter, expert_bias: torch.Tensor + ): + # scores shape (bs*slen, num_experts) + scores = F.linear(x, gate_weight) + + # By default, sigmoid or softmax is performed in float32 to avoid loss explosion + if self.score_func == "sigmoid": + scores = torch.sigmoid(scores.to(torch.float32)) + elif self.score_func == "softmax": + scores = F.softmax(scores.to(torch.float32), dim=1) + else: + raise NotImplementedError(f"Unknown score function {self.score_func}") + + # top scores shape (bs*slen, top_k) + # NOTE: The expert_bias is only used for routing. The gating value + # top_scores is still derived from the original scores. + if expert_bias is not None: + _, selected_experts_indices = torch.topk( + scores + expert_bias, k=self.top_k, dim=1 + ) + top_scores = scores.gather(dim=1, index=selected_experts_indices) + else: + top_scores, selected_experts_indices = torch.topk( + scores, k=self.top_k, dim=1 + ) + + # debug override: balanced round-robin routing + if self._debug_force_load_balance: + ( + selected_experts_indices, + top_scores, + ) = self._debug_force_load_balance_routing(scores) + + if self.route_norm: + denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20 + top_scores = top_scores / denominator + top_scores = top_scores * self.route_scale + + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward + num_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1), + bins=self.num_experts, + min=0, + max=self.num_experts, + ) + + return top_scores, selected_experts_indices, num_tokens_per_expert + + return functional_router_forward + + def _moe_forward( - x, router, expert_bias, reorderer, score_before_experts, experts, shared_experts + x: torch.Tensor, + router_gate_weight: torch.nn.Parameter, + expert_bias: Optional[torch.Tensor], + experts_w1: torch.Tensor, + experts_w3: torch.Tensor, + experts_w2: torch.Tensor, + shared_w1_weight: torch.Tensor, + shared_w3_weight: torch.Tensor, + shared_w2_weight: torch.Tensor, + functional_router_forward: Callable, + reorderer: nn.Module, # TokenReorderer ): bs, slen, dim = x.shape x = x.view(-1, dim) @@ -31,7 +100,7 @@ def _moe_forward( top_scores, selected_experts_indices, num_tokens_per_expert, - ) = router(x, expert_bias) + ) = functional_router_forward(x, router_gate_weight, expert_bias) num_tokens_per_expert_update = num_tokens_per_expert # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) @@ -56,26 +125,34 @@ def _moe_forward( # shape (bs*slen*top_k, dim) routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) - if score_before_experts: - routed_input = ( - routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) + # DSv3 score_before_experts is always False + # if score_before_experts: + # routed_input = ( + # routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) + # ).to(x.dtype) # shape (bs*slen*top_k, dim) - routed_output = experts(routed_input, num_tokens_per_expert) + # routed_output = experts(routed_input, num_tokens_per_expert) + routed_output = _run_experts_grouped_mm( + experts_w1, experts_w2, experts_w3, routed_input, num_tokens_per_expert + ) # shared expert # Note: we execute the shared expert before scoring the output of the routed expert # to "implicitly" overlap the shared expert compute with token combine communication - if shared_experts is not None: - out = shared_experts(x) - else: - out = torch.zeros_like(x) - - if not score_before_experts: - routed_output = ( - routed_output.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) + # if shared_experts is not None: + # out = shared_experts(x) + _h1 = F.linear(x, shared_w1_weight) + _h3 = F.linear(x, shared_w3_weight) + out = F.linear(F.silu(_h1) * _h3, shared_w2_weight) + # else: + # out = torch.zeros_like(x) + + # DSv3 score_before_experts is False + # if not score_before_experts: + routed_output = ( + routed_output.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) + ).to(x.dtype) out = out.scatter_add(dim=0, index=token_indices_experts_sorted, src=routed_output) out = out.reshape(bs, slen, dim) @@ -83,14 +160,19 @@ def _moe_forward( def moe_forward(self, x: torch.Tensor) -> torch.Tensor: + functional_router_forward = create_functional_router_forward(self.router) out, num_tokens_per_expert = _moe_forward( x, - self.router, + self.router.gate.weight, self.expert_bias, + self.experts.w1, + self.experts.w3, + self.experts.w2, + self.shared_experts.w1.weight, + self.shared_experts.w3.weight, + self.shared_experts.w2.weight, + functional_router_forward, self.reorderer, - self.score_before_experts, - self.experts, - self.shared_experts, ) # HOPs don't support buffer mutations, keep this outside # tokens_per_expert will be used to update the expert bias for load balancing. @@ -98,14 +180,24 @@ def moe_forward(self, x: torch.Tensor) -> torch.Tensor: # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- # first in the forward pass, and then in the backward pass. However, this has no # effect on the expert bias update thanks to the torch.sign() operator. - with torch.no_grad(): - self.tokens_per_expert.add_(num_tokens_per_expert) - with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) return out +def monkey_patch_checks(moe): + # causes data-dependent issue, hardcoded into monkey patch + assert not moe.score_before_experts + assert moe.router.gate.bias is None + assert moe.experts.use_grouped_mm + assert moe.shared_experts is not None + assert moe.shared_experts.w1.bias is None + assert moe.shared_experts.w2.bias is None + assert moe.shared_experts.w3.bias is None + assert not list(moe.reorderer.parameters()) + assert not list(moe.reorderer.buffers()) + + def monkey_patch_local_map_moe(model, world_mesh): """ TODO: fix HOPs not restoring the original signature. @@ -120,19 +212,21 @@ def monkey_patch_local_map_moe(model, world_mesh): _moe_forward = local_map( _moe_forward, out_placements=( - (Replicate(),), # (Shard(0),), - (Replicate(),), + (Replicate(),), # out: torch.Tensor + (Replicate(),), # num_tokens_per_expert_update: torch.Tensor ), in_placements=( - (Replicate(),), # (Shard(0),), - (Replicate(),), - (Replicate(),), - (Replicate(),), - (Replicate(),), - (Replicate(),), - (Replicate(),), - (Replicate(),), - (Replicate(),), + (Replicate(),), # x: torch.Tensor, + (Replicate(),), # router_gate_weight: torch.nn.Parameter, + (Replicate(),), # expert_bias: Optional[torch.Tensor], + (Replicate(),), # experts_w1: torch.Tensor, + (Replicate(),), # experts_w3: torch.Tensor, + (Replicate(),), # experts_w2: torch.Tensor, + (Replicate(),), # shared_w1: torch.Tensor, + (Replicate(),), # shared_w3: torch.Tensor, + (Replicate(),), # shared_w2: torch.Tensor, + None, # functional_router_forward: Callable, + None, # reorderer: TokenReorderer, ), redistribute_inputs=True, in_grad_placements=None, @@ -143,30 +237,7 @@ def monkey_patch_local_map_moe(model, world_mesh): if not block.moe_enabled: continue block.moe.forward = types.MethodType(moe_forward, block.moe) - - # torch.distributed.breakpoint() - # moe.forward = moe_forward - # moe._moe_forward = local_map( - # moe._moe_forward, - # out_placements=( - # (Replicate(),), # (Shard(0),), - # (Replicate(),), - # ), - # in_placements=( - # (Replicate(),), # (Shard(0),), - # (Replicate(),), - # (Replicate(),), - # (Replicate(),), - # (Replicate(),), - # (Replicate(),), - # (Replicate(),), - # (Replicate(),), - # (Replicate(),), - # ), - # redistribute_inputs=True, - # in_grad_placements=None, - # device_mesh=mesh, - # ) + monkey_patch_checks(block.moe) # Run workflow with: From ac0def979ad1e442afdb8dd4b80aa204c0b1f6e6 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 18 Nov 2025 16:23:26 -0800 Subject: [PATCH 31/49] move flex disables into experiment folder --- .../auto_parallel/deepseek_v3/__init__.py | 15 ++++++++++++++- torchtitan/models/deepseek_v3/__init__.py | 12 ++++++------ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py b/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py index ff220c496b..7aa7f98f9e 100644 --- a/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py @@ -6,6 +6,8 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. +import copy + from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing @@ -14,6 +16,7 @@ from torchtitan.hf_datasets.text_datasets import build_text_dataloader from torchtitan.models.deepseek_v3 import deepseekv3_args, DeepSeekV3Model +from torchtitan.models.deepseek_v3.model.args import DeepSeekV3ModelArgs from torchtitan.models.deepseek_v3.model.state_dict_adapter import ( DeepSeekV3StateDictAdapter, ) @@ -23,9 +26,19 @@ def get_train_spec() -> TrainSpec: + model_args = copy.deepcopy(deepseekv3_args) + + default_args = DeepSeekV3ModelArgs() + for config, args in model_args.items(): + if "flex_attn" in config: + continue + + use_flex_attn = (default_args.use_flex_attn,) + attn_mask_type = (default_args.attn_mask_type,) + return TrainSpec( model_cls=DeepSeekV3Model, - model_args=deepseekv3_args, + model_args=model_args, parallelize_fn=parallelize_deepseekv3, pipelining_fn=pipeline_llm, build_optimizers_fn=build_optimizers_with_moe_load_balancing, diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 969c4296e7..525bd96c13 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -97,8 +97,8 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - # use_flex_attn=True, - # attn_mask_type="block_causal", + use_flex_attn=True, + attn_mask_type="block_causal", ), "236B": DeepSeekV3ModelArgs( vocab_size=102400, @@ -124,8 +124,8 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - # use_flex_attn=True, - # attn_mask_type="block_causal", + use_flex_attn=True, + attn_mask_type="block_causal", ), "671B": DeepSeekV3ModelArgs( vocab_size=129280, @@ -151,8 +151,8 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - # use_flex_attn=True, - # attn_mask_type="block_causal", + use_flex_attn=True, + attn_mask_type="block_causal", ), } From a24ef073274d83d3ab00c62bade07194522f8a26 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 18 Nov 2025 16:25:52 -0800 Subject: [PATCH 32/49] fix newline --- torchtitan/models/deepseek_v3/model/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 285df3bb54..3cf56eb1b2 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -291,8 +291,8 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.attention = Attention(model_args) self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.moe_enabled = layer_id >= model_args.n_dense_layers + self.moe_enabled = layer_id >= model_args.n_dense_layers if self.moe_enabled: self.moe = MoE( model_args.moe_args, From da611e4ee20a8cb50fb27a25117be28babc0f063 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 18 Nov 2025 16:38:50 -0800 Subject: [PATCH 33/49] no longer necessary train.py changes --- torchtitan/train.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index b01180787a..3352ad9fde 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -130,7 +130,6 @@ def __init__(self, job_config: JobConfig): # build model (using meta init) model_args = self.train_spec.model_args[job_config.model.flavor] - model_cls = self.train_spec.model_cls # set the model args from training job configs model_args.update_from_config(job_config) self.model_args = model_args @@ -144,11 +143,9 @@ def __init__(self, job_config: JobConfig): ): model = self.train_spec.model_cls(model_args) - with torch.device("meta"): - model = model_cls(model_args) - # Build the collection of model converters. No-op if `model.converters` empty - model_converters = build_model_converters(job_config, parallel_dims) - model_converters.convert(model) + # Build the collection of model converters. No-op if `model.converters` empty + model_converters = build_model_converters(job_config, parallel_dims) + model_converters.convert(model) # metrics logging build_metrics_processor_fn = ( From 6cc8caaa042129885b30664cceb8515711aed75c Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 18 Nov 2025 16:40:42 -0800 Subject: [PATCH 34/49] restore comment --- torchtitan/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 3352ad9fde..5cfab998b2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -244,7 +244,7 @@ def __init__(self, job_config: JobConfig): # confirm that user will be able to view loss metrics on the console ensure_pp_loss_visible(parallel_dims, job_config, color) else: - # apply Autoparallel + # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) model.to_empty(device=init_device) From d54a6d4c92e063189c91b433e3b3d85d79dfb657 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 18 Nov 2025 17:01:52 -0800 Subject: [PATCH 35/49] temporarily extend hacky optimizer stuff to make dsv3 ap 1d run again --- torchtitan/components/optimizer.py | 44 +++++++++++++++--------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 49bc9bd534..24a69d975f 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -340,10 +340,30 @@ def build_optimizers_with_moe_load_balancing( ft_manager=ft_manager, ) + # AP friendly methods + def is_moe_block(block): + moe_enabled = getattr(block, "moe_enabled", False) + has_moe_submod = hasattr(block, "moe") # AP + return moe_enabled or has_moe_submod + + def should_manual_allreduce(tokens_per_expert_by_layer): + return not isinstance( + tokens_per_expert_by_layer, torch.distributed.tensor.DTensor + ) + + def get_transformer_blocks(model_part): + if isinstance(model_part.layers, nn.ModuleDict): + # regular torchtitan + blocks = model_part.layers.values() + else: + # TODO: fix autoparallel to preserve the module dict + blocks = model_part.layers.children() + return blocks + def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool: for model_part in model_parts: - for transformer_block in model_part.layers.values(): - if transformer_block.moe_enabled: + for transformer_block in get_transformer_blocks(model_part): + if is_moe_block(transformer_block): # Assumption: load_balance_coeff is set universally on all moe blocks. return bool(transformer_block.moe.load_balance_coeff) return False @@ -360,26 +380,6 @@ def _update_expert_bias( parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None ) - # AP friendly methods - def is_moe_block(block): - moe_enabled = getattr(block, "moe_enabled", False) - has_moe_submod = hasattr(block, "moe") # AP - return moe_enabled or has_moe_submod - - def get_transformer_blocks(model_part): - if isinstance(model_part.layers, nn.ModuleDict): - # regular torchtitan - blocks = model_part.layers.values() - else: - # TODO: fix autoparallel to preserve the module dict - blocks = model_part.layers.children() - return blocks - - def should_manual_allreduce(tokens_per_expert_by_layer): - return not isinstance( - tokens_per_expert_by_layer, torch.distributed.tensor.DTensor - ) - # TODO: Currently this sync is blocking (thus exposed) and happens on the # default compute stream. Need to assess if this is OK performance-wise. tokens_per_expert_list = [] From 2b1fb92b08470625660f7eb2e77eff29df32d8c5 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 21 Nov 2025 09:46:23 -0800 Subject: [PATCH 36/49] fix moduledict with AP https://github.com/meta-pytorch/autoparallel/pull/260 --- torchtitan/components/optimizer.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 24a69d975f..3c9db88b9e 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -351,18 +351,9 @@ def should_manual_allreduce(tokens_per_expert_by_layer): tokens_per_expert_by_layer, torch.distributed.tensor.DTensor ) - def get_transformer_blocks(model_part): - if isinstance(model_part.layers, nn.ModuleDict): - # regular torchtitan - blocks = model_part.layers.values() - else: - # TODO: fix autoparallel to preserve the module dict - blocks = model_part.layers.children() - return blocks - def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool: for model_part in model_parts: - for transformer_block in get_transformer_blocks(model_part): + for transformer_block in model_part.layers.values(): if is_moe_block(transformer_block): # Assumption: load_balance_coeff is set universally on all moe blocks. return bool(transformer_block.moe.load_balance_coeff) @@ -384,8 +375,7 @@ def _update_expert_bias( # default compute stream. Need to assess if this is OK performance-wise. tokens_per_expert_list = [] for model_part in model_parts: - blocks = get_transformer_blocks(model_part) - for transformer_block in blocks: + for transformer_block in model_part.layers.values(): if not is_moe_block(transformer_block): continue if transformer_block.moe.load_balance_coeff is None: @@ -414,8 +404,7 @@ def _update_expert_bias( moe_layer_idx = 0 with torch.no_grad(): for model_part in model_parts: - blocks = get_transformer_blocks(model_part) - for transformer_block in blocks: + for transformer_block in model_part.layers.values(): if not is_moe_block(transformer_block): continue moe = transformer_block.moe From 68245d656c441a475fe33ea3fd53eeda008d8b17 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 21 Nov 2025 11:58:18 -0800 Subject: [PATCH 37/49] fix moe_enabled --- torchtitan/components/optimizer.py | 11 +++-------- .../deepseek_v3/parallelize_deepseekv3.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 3c9db88b9e..20233ba065 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -341,11 +341,6 @@ def build_optimizers_with_moe_load_balancing( ) # AP friendly methods - def is_moe_block(block): - moe_enabled = getattr(block, "moe_enabled", False) - has_moe_submod = hasattr(block, "moe") # AP - return moe_enabled or has_moe_submod - def should_manual_allreduce(tokens_per_expert_by_layer): return not isinstance( tokens_per_expert_by_layer, torch.distributed.tensor.DTensor @@ -354,7 +349,7 @@ def should_manual_allreduce(tokens_per_expert_by_layer): def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool: for model_part in model_parts: for transformer_block in model_part.layers.values(): - if is_moe_block(transformer_block): + if transformer_block.moe_enabled: # Assumption: load_balance_coeff is set universally on all moe blocks. return bool(transformer_block.moe.load_balance_coeff) return False @@ -376,7 +371,7 @@ def _update_expert_bias( tokens_per_expert_list = [] for model_part in model_parts: for transformer_block in model_part.layers.values(): - if not is_moe_block(transformer_block): + if not transformer_block.moe_enabled: continue if transformer_block.moe.load_balance_coeff is None: return @@ -405,7 +400,7 @@ def _update_expert_bias( with torch.no_grad(): for model_part in model_parts: for transformer_block in model_part.layers.values(): - if not is_moe_block(transformer_block): + if not transformer_block.moe_enabled: continue moe = transformer_block.moe diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py index 66a5b71354..89092dec64 100644 --- a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py @@ -240,6 +240,14 @@ def monkey_patch_local_map_moe(model, world_mesh): monkey_patch_checks(block.moe) +# TODO: Autoparallel should transparently wrap the original nn.Module +# but I don't know how to do that. +def set_torchtitan_fields(orig, new): + assert isinstance(new.layers, torch.nn.ModuleDict) + for block in new.layers.values(): + block.moe_enabled = hasattr(block, "moe") + + # Run workflow with: # CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel def parallelize_deepseekv3( @@ -352,6 +360,8 @@ def input_fn(): logger.info(f"AutoParallel took {t1 - t0} seconds") parallel_mod = autop.apply_placement(sharding_placement) + set_torchtitan_fields(model, parallel_mod) + if loss_parallel_enabled: # current PyTorch's implementation of loss parallel assumes From e592e22a784d1426df17534a4f8f3ecd33f34475 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 21 Nov 2025 12:01:07 -0800 Subject: [PATCH 38/49] lint --- torchtitan/components/optimizer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 20233ba065..8eec355b69 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -340,7 +340,6 @@ def build_optimizers_with_moe_load_balancing( ft_manager=ft_manager, ) - # AP friendly methods def should_manual_allreduce(tokens_per_expert_by_layer): return not isinstance( tokens_per_expert_by_layer, torch.distributed.tensor.DTensor @@ -365,7 +364,6 @@ def _update_expert_bias( dp_cp_mesh = ( parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None ) - # TODO: Currently this sync is blocking (thus exposed) and happens on the # default compute stream. Need to assess if this is OK performance-wise. tokens_per_expert_list = [] From 737ad2cd3ee29f48f66dfff0611ef5679159daa6 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 21 Nov 2025 13:01:44 -0800 Subject: [PATCH 39/49] job config --- torchtitan/config/job_config.py | 5 --- .../experiments/auto_parallel/job_config.py | 44 +++++++++++++++++++ 2 files changed, 44 insertions(+), 5 deletions(-) create mode 100644 torchtitan/experiments/auto_parallel/job_config.py diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 5918d2918d..95588d2c3b 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -865,11 +865,6 @@ class Experimental: needs to ensure that the path can be imported. """ - # "aten" (default), "inductor", "none" - comms_bucket_reorder_strategy: str = "aten" - - autop_force_bf16: bool = False - @dataclass class Validation: diff --git a/torchtitan/experiments/auto_parallel/job_config.py b/torchtitan/experiments/auto_parallel/job_config.py new file mode 100644 index 0000000000..5f93d16bfb --- /dev/null +++ b/torchtitan/experiments/auto_parallel/job_config.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + + +""" +Use --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config +""" + + +@dataclass +class Experimental: + custom_import: str = "" + """ + This option enables the importation of external modules. + Currently, it only supports dotted import modules (e.g., some_package.model_x). + It is the user's responsibility to ensure that the specified path can be + successfully imported. One method to achieve this, you can place your module + inside the ``torchtitan/torchtitan`` folder and execute ``pip install -e .`` to + make it available for import. + """ + + custom_args_module: str = "" + """ + DEPRECATED (moved to Job.custom_config_module). Will be removed soon. + + This option allows users to extend TorchTitan's existing JobConfig by extending + a user defined JobConfig dataclass. Similar to ``--experimental.custom_import``, the user + needs to ensure that the path can be imported. + """ + + # "aten" (default), "inductor", "none" + comms_bucket_reorder_strategy: str = "aten" + + autop_force_bf16: bool = False + + +@dataclass +class JobConfig: + experimental: Experimental = field(default_factory=Experimental) From 64e605094f1fb98dfd82d2ffe8f2ddaea1d7e136 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 21 Nov 2025 14:54:45 -0800 Subject: [PATCH 40/49] remove MAST specific profiling logs --- torchtitan/tools/profiling.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 7edfe66979..710397568c 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -16,9 +16,6 @@ # how much memory allocation/free ops to record in memory snapshots MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 -PERFETTO_UI_ROOT_URL = ( - "https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html" -) @contextlib.contextmanager @@ -50,25 +47,13 @@ def trace_handler(prof): logger.info(f"Dumping profiler traces at step {prof.step_num}") begin = time.monotonic() + output_file = os.path.join(curr_trace_dir, f"rank{rank}_trace.json") prof.export_chrome_trace(output_file) - log_str = f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" - # not directly landable on upstream titan, - # but conveniently prints the internal url for perfetto on manifold for mast jobs - manifold_mount_prefix = "/mnt/mffuse/" - if output_file.find(manifold_mount_prefix) == 0: - manifold_path = os.path.join( - "torchtrain_datasets/tree", - output_file.split(manifold_mount_prefix)[1], - ) - perfetto_url = ( - PERFETTO_UI_ROOT_URL - + "#!/?url=https://interncache-all.fbcdn.net/manifold/" - + manifold_path - ) - log_str += f": {perfetto_url}" - logger.info(log_str) + logger.info( + f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" + ) logger.info(f"Profiling active. Traces will be saved at {trace_dir}") From de6dca6ed76f502f1fafc35e4987352fb11450c5 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 21 Nov 2025 15:36:53 -0800 Subject: [PATCH 41/49] update readme --- torchtitan/experiments/README.md | 1 + torchtitan/experiments/auto_parallel/README.md | 14 +++++++++----- .../experiments/auto_parallel/llama3/__init__.py | 4 ---- torchtitan/tools/profiling.py | 1 - 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index 08dc692bf9..79bd6ffeca 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -32,3 +32,4 @@ We provide this `experiments/` folder to host experiments that add significant v | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) | | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) | | [transformers_backend](./transformers_backend/) | [![Transformers backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) | +| [auto_parallel](./auto_parallel/) | TBA | [@wconstab](https://github.com/wconstab) | [@xmfan](https://github.com/xmfan) | diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/auto_parallel/README.md index 7e112329b9..68c7afb09e 100644 --- a/torchtitan/experiments/auto_parallel/README.md +++ b/torchtitan/experiments/auto_parallel/README.md @@ -1,11 +1,15 @@ ## Auto Parallel -requires installing git@github.com:pytorch-labs/autoparallel.git +### Overview -`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4` +The Auto Parallel experiment integrates PyTorch's AutoParallel framework with TorchTitan to automatically optimize distributed training parallelism strategies given a device mesh. Instead of manually configuring parallelism layouts, AutoParallel uses cost-based analysis to determine optimal sharding placements for model parameters, activations, and gradients. -Use simplefsdp's autobucketing pass: +### Usage -`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --experimental.enable_simplefsdp_passes --compile.enable` +Requires installing [git@github.com:pytorch-labs/autoparallel.git](https://github.com/meta-pytorch/autoparallel) -(or llama3-8b.toml) +### Llama3 +`CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.llama3 --parallelism.tensor_parallel_degree 4 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config` + +### DeepSeekv3 [WIP] +`CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config` diff --git a/torchtitan/experiments/auto_parallel/llama3/__init__.py b/torchtitan/experiments/auto_parallel/llama3/__init__.py index f9e61ddd7e..ea38ac631a 100644 --- a/torchtitan/experiments/auto_parallel/llama3/__init__.py +++ b/torchtitan/experiments/auto_parallel/llama3/__init__.py @@ -21,10 +21,6 @@ from .parallelize_llama import parallelize_llama -# CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml -# ./run_train.sh --model.name auto_parallel.llama3 --parallelism.tensor_parallel_degree 4 - - def get_train_spec() -> TrainSpec: return TrainSpec( model_cls=Transformer, diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 710397568c..f398dba9b5 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -49,7 +49,6 @@ def trace_handler(prof): begin = time.monotonic() output_file = os.path.join(curr_trace_dir, f"rank{rank}_trace.json") - prof.export_chrome_trace(output_file) logger.info( f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" From fe0b6cc8857ac81373c6baffdc5bb27f393e48c7 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 21 Nov 2025 15:43:49 -0800 Subject: [PATCH 42/49] format readme --- torchtitan/experiments/auto_parallel/README.md | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/auto_parallel/README.md index 68c7afb09e..55dcc3c5e5 100644 --- a/torchtitan/experiments/auto_parallel/README.md +++ b/torchtitan/experiments/auto_parallel/README.md @@ -4,12 +4,16 @@ The Auto Parallel experiment integrates PyTorch's AutoParallel framework with TorchTitan to automatically optimize distributed training parallelism strategies given a device mesh. Instead of manually configuring parallelism layouts, AutoParallel uses cost-based analysis to determine optimal sharding placements for model parameters, activations, and gradients. -### Usage +### Requirements -Requires installing [git@github.com:pytorch-labs/autoparallel.git](https://github.com/meta-pytorch/autoparallel) +Requires installing [git@github.com:meta-pytorch/autoparallel.git](https://github.com/meta-pytorch/autoparallel) + +### Single Node + +**Llama3** -### Llama3 `CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.llama3 --parallelism.tensor_parallel_degree 4 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config` -### DeepSeekv3 [WIP] +**DeepSeekv3** + `CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config` From 2b37f30a91e9be028253227a87fbb96fedc8ee65 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 21 Nov 2025 18:13:30 -0800 Subject: [PATCH 43/49] comments --- torchtitan/components/optimizer.py | 9 +++------ .../experiments/auto_parallel/job_config.py | 19 ------------------- 2 files changed, 3 insertions(+), 25 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 8eec355b69..87dc0f0e0b 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -340,11 +340,6 @@ def build_optimizers_with_moe_load_balancing( ft_manager=ft_manager, ) - def should_manual_allreduce(tokens_per_expert_by_layer): - return not isinstance( - tokens_per_expert_by_layer, torch.distributed.tensor.DTensor - ) - def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool: for model_part in model_parts: for transformer_block in model_part.layers.values(): @@ -385,7 +380,9 @@ def _update_expert_bias( tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list) if dp_cp_mesh is not None: - if should_manual_allreduce(tokens_per_expert_by_layer): + if isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor): + tokens_per_expert_by_layer = tokens_per_expert_by_layer.full_tensor() + else: # Perform single all-reduce to get global statistics across all processes pg = dp_cp_mesh.get_group() torch.distributed.all_reduce( diff --git a/torchtitan/experiments/auto_parallel/job_config.py b/torchtitan/experiments/auto_parallel/job_config.py index 5f93d16bfb..c880cadb31 100644 --- a/torchtitan/experiments/auto_parallel/job_config.py +++ b/torchtitan/experiments/auto_parallel/job_config.py @@ -14,25 +14,6 @@ @dataclass class Experimental: - custom_import: str = "" - """ - This option enables the importation of external modules. - Currently, it only supports dotted import modules (e.g., some_package.model_x). - It is the user's responsibility to ensure that the specified path can be - successfully imported. One method to achieve this, you can place your module - inside the ``torchtitan/torchtitan`` folder and execute ``pip install -e .`` to - make it available for import. - """ - - custom_args_module: str = "" - """ - DEPRECATED (moved to Job.custom_config_module). Will be removed soon. - - This option allows users to extend TorchTitan's existing JobConfig by extending - a user defined JobConfig dataclass. Similar to ``--experimental.custom_import``, the user - needs to ensure that the path can be imported. - """ - # "aten" (default), "inductor", "none" comms_bucket_reorder_strategy: str = "aten" From bc18d875a3360d83b00ccd08dbaf1a69eacfe058 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 21 Nov 2025 18:19:24 -0800 Subject: [PATCH 44/49] manual redistribute --- torchtitan/components/optimizer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 87dc0f0e0b..2d7c6f7901 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -381,7 +381,9 @@ def _update_expert_bias( if dp_cp_mesh is not None: if isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor): - tokens_per_expert_by_layer = tokens_per_expert_by_layer.full_tensor() + tokens_per_expert_by_layer = tokens_per_expert_by_layer.redistribute( + placements=[Replicate()] * dp_cp_mesh.ndim + ) else: # Perform single all-reduce to get global statistics across all processes pg = dp_cp_mesh.get_group() From 5fdf737013f472d5f98d13251969b265f9364cb2 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 21 Nov 2025 18:26:02 -0800 Subject: [PATCH 45/49] imports --- torchtitan/components/optimizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 2d7c6f7901..7746d7f89a 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -16,6 +16,7 @@ StateDictOptions, ) from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.tensor import Replicate from torch.optim import Optimizer from torchtitan.components.ft import FTManager, has_torchft From c1a307f805146e92e744314c9e9b27aafaf51e06 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Sun, 23 Nov 2025 13:32:26 -0800 Subject: [PATCH 46/49] mesh --- torchtitan/components/optimizer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 7746d7f89a..80557366da 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -383,7 +383,8 @@ def _update_expert_bias( if dp_cp_mesh is not None: if isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor): tokens_per_expert_by_layer = tokens_per_expert_by_layer.redistribute( - placements=[Replicate()] * dp_cp_mesh.ndim + placements=[Replicate()] + * tokens_per_expert_by_layer.device_mesh.ndim ) else: # Perform single all-reduce to get global statistics across all processes From c480cd14309b0c5181c1a8253b6794b576361442 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Sun, 23 Nov 2025 20:57:23 -0800 Subject: [PATCH 47/49] no flex --- torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py b/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py index 7aa7f98f9e..b90583c86b 100644 --- a/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py @@ -33,8 +33,8 @@ def get_train_spec() -> TrainSpec: if "flex_attn" in config: continue - use_flex_attn = (default_args.use_flex_attn,) - attn_mask_type = (default_args.attn_mask_type,) + args.attn_type = default_args.attn_type + args.attn_mask_type = default_args.attn_mask_type return TrainSpec( model_cls=DeepSeekV3Model, From aa739f6180c94c8c5a05568b9bf2ed06b4785097 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Sun, 23 Nov 2025 21:20:09 -0800 Subject: [PATCH 48/49] update with new moe --- .../deepseek_v3/parallelize_deepseekv3.py | 49 +++++++++++-------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py index 89092dec64..fc278cfabe 100644 --- a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py @@ -90,11 +90,12 @@ def _moe_forward( shared_w2_weight: torch.Tensor, functional_router_forward: Callable, reorderer: nn.Module, # TokenReorderer + top_k: int, ): bs, slen, dim = x.shape x = x.view(-1, dim) - # top_scores and selected_experts_indices shape (bs*slen*top_k,) + # top_scores and selected_experts_indices shape (bs*slen, top_k) # num_tokens_per_expert shape (num_experts,) ( top_scores, @@ -103,7 +104,7 @@ def _moe_forward( ) = functional_router_forward(x, router_gate_weight, expert_bias) num_tokens_per_expert_update = num_tokens_per_expert - # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) + # top_scores_experts_sorted and token_indices_experts_sorted shape (bs*slen*top_k,) # num_tokens_per_expert shape (num_experts,) # NOTE: the reason we need to compute num_tokens_per_expert again is: # 1st computation in router is to update self.tokens_per_expert @@ -118,12 +119,7 @@ def _moe_forward( ) = reorderer(top_scores, selected_experts_indices) # shape (bs*slen*top_k, dim) - token_indices_experts_sorted = token_indices_experts_sorted.reshape(-1, 1).expand( - -1, dim - ) - - # shape (bs*slen*top_k, dim) - routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) + routed_input = x[token_indices_experts_sorted // top_k] # DSv3 score_before_experts is always False # if score_before_experts: @@ -137,26 +133,37 @@ def _moe_forward( experts_w1, experts_w2, experts_w3, routed_input, num_tokens_per_expert ) - # shared expert + # always has shared expert # Note: we execute the shared expert before scoring the output of the routed expert # to "implicitly" overlap the shared expert compute with token combine communication - # if shared_experts is not None: - # out = shared_experts(x) _h1 = F.linear(x, shared_w1_weight) _h3 = F.linear(x, shared_w3_weight) out = F.linear(F.silu(_h1) * _h3, shared_w2_weight) - # else: - # out = torch.zeros_like(x) + # Unsort routed outputs + routed_output_unsorted = torch.zeros( + (bs * slen * top_k, dim), + dtype=routed_output.dtype, + device=routed_output.device, + ) + routed_output_unsorted[token_indices_experts_sorted] = routed_output + routed_output_unsorted = routed_output_unsorted.reshape(-1, top_k, dim) # DSv3 score_before_experts is False - # if not score_before_experts: - routed_output = ( - routed_output.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) + # if not self.score_before_experts: + out_experts = ( + torch.bmm( + top_scores.reshape(-1, 1, top_k), + routed_output_unsorted.float(), + ) + .to(x.dtype) + .squeeze(1) + ) + # else: + # out_experts = routed_output_unsorted.sum(dim=1) - out = out.scatter_add(dim=0, index=token_indices_experts_sorted, src=routed_output) - out = out.reshape(bs, slen, dim) - return out, num_tokens_per_expert_update + # always has shared experts + # if out is None: + return (out + out_experts).reshape(bs, slen, dim), num_tokens_per_expert_update def moe_forward(self, x: torch.Tensor) -> torch.Tensor: @@ -173,6 +180,7 @@ def moe_forward(self, x: torch.Tensor) -> torch.Tensor: self.shared_experts.w2.weight, functional_router_forward, self.reorderer, + self.router.top_k, ) # HOPs don't support buffer mutations, keep this outside # tokens_per_expert will be used to update the expert bias for load balancing. @@ -227,6 +235,7 @@ def monkey_patch_local_map_moe(model, world_mesh): (Replicate(),), # shared_w2: torch.Tensor, None, # functional_router_forward: Callable, None, # reorderer: TokenReorderer, + None, # top_k ), redistribute_inputs=True, in_grad_placements=None, From f03fe9e9fc78f4b980ad3a73ce8be842ec1891a1 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Mon, 24 Nov 2025 15:44:05 -0800 Subject: [PATCH 49/49] remove transformers_backend --- torchtitan/experiments/README.md | 1 - torchtitan/experiments/__init__.py | 1 - 2 files changed, 2 deletions(-) diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index 5874f7401a..aa93628656 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -31,6 +31,5 @@ We provide this `experiments/` folder to host experiments that add significant v | [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) | | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) | | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) | -| [transformers_backend](./transformers_backend/) | [![Transformers backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) | | [transformers_modeling_backend](./transformers_modeling_backend/) | [![Transformers modeling backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) | | [auto_parallel](./auto_parallel/) | TBA | [@wconstab](https://github.com/wconstab) | [@xmfan](https://github.com/xmfan) | diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 3c0cfc1939..7e2c442103 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -12,7 +12,6 @@ "vlm", "compiler_toolkit.deepseek_v3", "compiler_toolkit.llama3", - "transformers_backend", "transformers_modeling_backend", "auto_parallel.llama3", "auto_parallel.deepseek_v3",