Merged

Changes from all commits (54 commits)
- 3ccd12c  [WIP] Integrate autoparallel into torchtitan (wconstab, Jun 13, 2025)
- e6d2caf  Autoparallel support for DP-only, DP+TP, or TP-only (wconstab, Jun 27, 2025)
- 68476b3  Update CLI inductor configs for bucketing/reordering (wconstab, Jul 25, 2025)
- 9ee9f75  add back llama3_autoparallel_init_fn (wconstab, Jul 25, 2025)
- f6e4099  Track API change from new AOTAutograd interface (ezyang, Jul 28, 2025)
- 4d7ee8a  Support forcing the model into bf16 for perf debugging (wconstab, Jul 28, 2025)
- b801d0b  Integrate MixedPrecision with AutoParallel and fix example_inputs (wconstab, Jul 29, 2025)
- b099cf9  Use in-place compile API (ezyang, Jul 29, 2025)
- b3587d9  Fix bucketing pass configs (wconstab, Jul 29, 2025)
- 42c2c07  Support both eager and autoparallel init based on model.name (wconstab, Jul 30, 2025)
- d93845e  Remove llama3 init weights hack (wconstab, Aug 6, 2025)
- 60f5f11  Print profiling manifold url (wconstab, Aug 7, 2025)
- 6c782eb  Support new compile API from autoparallel PR #77 (wconstab, Aug 8, 2025)
- 4712163  Fix bucket sizes for AutoParallel 1D (#1545) (fmassa, Aug 8, 2025)
- 3f04d22  Add support for loss parallel (#1546) (fmassa, Aug 10, 2025)
- 8e50870  Add config for running simple-fsdp bucketing/reordering passes (wconstab, Aug 18, 2025)
- 91c5639  Hook up deepseekv3_auto_parallel (wconstab, Aug 19, 2025)
- 1233902  [dsv3] patch graph break fix, works up until sharding rules (xmfan, Aug 19, 2025)
- 4f8677b  update simplefsdp pass config (ruisizhang123, Aug 21, 2025)
- 714cc5b  [dsv3] disable MoE while we fix local_map, works up until optimizer (xmfan, Aug 22, 2025)
- 45647b3  Merge branch 'main' into whc/merge_autoparallel (wconstab, Aug 28, 2025)
- bfa9f7f  tweak ds3 model.py to reflect main branch for DS3 baseline can run (#… (bdhirsh, Sep 5, 2025)
- 75fb2eb  add simplefsdp's autobucketing pass entry (#1658) (ruisizhang123, Sep 6, 2025)
- 8769396  [dsv3] 1D AP w/ local_map (xmfan, Sep 11, 2025)
- db22479  [dsv3] Turn off Flex for AP (xmfan, Sep 17, 2025)
- 87ef4e0  Merge branch 'main' into autoparallel (xmfan, Oct 27, 2025)
- 9dc0bd8  Update to new model registration API (xmfan, Oct 27, 2025)
- c6e25bd  Whc/knobs (#1994) (wconstab, Nov 6, 2025)
- 26410e8  Merge remote-tracking branch 'origin/main' into autoparallel (xmfan, Nov 18, 2025)
- e6ea814  lint (xmfan, Nov 18, 2025)
- 7abede8  undo moe patching (xmfan, Nov 18, 2025)
- d2e76b7  move inductor config into experiment folders (xmfan, Nov 18, 2025)
- 472b4ad  fix local_map moe patch (xmfan, Nov 19, 2025)
- ac0def9  move flex disables into experiment folder (xmfan, Nov 19, 2025)
- a24ef07  fix newline (xmfan, Nov 19, 2025)
- da611e4  no longer necessary train.py changes (xmfan, Nov 19, 2025)
- 6cc8caa  restore comment (xmfan, Nov 19, 2025)
- d54a6d4  temporarily extend hacky optimizer stuff to make dsv3 ap 1d run again (xmfan, Nov 19, 2025)
- acd9588  Merge remote-tracking branch 'origin/main' into autoparallel (xmfan, Nov 21, 2025)
- 2b1fb92  fix moduledict with AP https://github.com/meta-pytorch/autoparallel/p… (xmfan, Nov 21, 2025)
- 68245d6  fix moe_enabled (xmfan, Nov 21, 2025)
- e592e22  lint (xmfan, Nov 21, 2025)
- 737ad2c  job config (xmfan, Nov 21, 2025)
- 64e6050  remove MAST specific profiling logs (xmfan, Nov 21, 2025)
- de6dca6  update readme (xmfan, Nov 21, 2025)
- fe0b6cc  format readme (xmfan, Nov 21, 2025)
- 2b37f30  comments (xmfan, Nov 22, 2025)
- bc18d87  manual redistribute (xmfan, Nov 22, 2025)
- 5fdf737  imports (xmfan, Nov 22, 2025)
- c1a307f  mesh (xmfan, Nov 23, 2025)
- 58d5349  Merge remote-tracking branch 'origin/main' into autoparallel (xmfan, Nov 23, 2025)
- c480cd1  no flex (xmfan, Nov 24, 2025)
- aa739f6  update with new moe (xmfan, Nov 24, 2025)
- f03fe9e  remove transformers_backend (xmfan, Nov 24, 2025)
19 changes: 14 additions & 5 deletions torchtitan/components/optimizer.py
```diff
@@ -16,6 +16,7 @@
     StateDictOptions,
 )
 from torch.distributed.checkpoint.stateful import Stateful
+from torch.distributed.tensor import Replicate
 from torch.optim import Optimizer
 
 from torchtitan.components.ft import FTManager, has_torchft
@@ -380,11 +381,19 @@ def _update_expert_bias(
         tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list)
 
         if dp_cp_mesh is not None:
-            # Perform single all-reduce to get global statistics across all processes
-            pg = dp_cp_mesh.get_group()
-            torch.distributed.all_reduce(
-                tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM
-            )
+            if isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor):
+                tokens_per_expert_by_layer = tokens_per_expert_by_layer.redistribute(
+                    placements=[Replicate()]
+                    * tokens_per_expert_by_layer.device_mesh.ndim
+                )
+            else:
+                # Perform single all-reduce to get global statistics across all processes
+                pg = dp_cp_mesh.get_group()
+                torch.distributed.all_reduce(
+                    tokens_per_expert_by_layer,
+                    group=pg,
+                    op=torch.distributed.ReduceOp.SUM,
+                )
 
         moe_layer_idx = 0
         with torch.no_grad():
```
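The new branch handles the case where `tokens_per_expert_by_layer` is a DTensor under AutoParallel: rather than a manual all-reduce over the process group, the tensor is redistributed to a fully replicated placement, which performs the reduction. Below is a minimal standalone sketch of that pattern, assuming a 1-D mesh launched via torchrun; the tensor values and mesh are illustrative, not taken from torchtitan.

```python
# Sketch of the DTensor branch above: a tensor holding partial (per-rank) sums
# is redistributed to Replicate, which all-reduces so every rank ends up with
# the global totals. Run with e.g.: torchrun --nproc_per_node=2 sketch.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

dist.init_process_group("gloo")  # use "nccl" on GPUs
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Each rank's local per-expert token counts (toy values).
local_counts = torch.ones(4) * (dist.get_rank() + 1)
counts = DTensor.from_local(local_counts, mesh, placements=[Partial()])

if isinstance(counts, DTensor):
    # Same pattern as the optimizer hook: replicate across every mesh dim.
    counts = counts.redistribute(placements=[Replicate()] * counts.device_mesh.ndim)

print(counts.to_local())  # global sums, identical on every rank
dist.destroy_process_group()
```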
1 change: 1 addition & 0 deletions torchtitan/experiments/README.md
```diff
@@ -32,3 +32,4 @@ We provide this `experiments/` folder to host experiments that add significant v
 | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) |
 | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) |
 | [transformers_modeling_backend](./transformers_modeling_backend/) | [![Transformers modeling backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) |
+| [auto_parallel](./auto_parallel/) | TBA | [@wconstab](https://github.com/wconstab) [@xmfan](https://github.com/xmfan) |
```
2 changes: 2 additions & 0 deletions torchtitan/experiments/__init__.py
```diff
@@ -13,5 +13,7 @@
         "compiler_toolkit.deepseek_v3",
         "compiler_toolkit.llama3",
         "transformers_modeling_backend",
+        "auto_parallel.llama3",
+        "auto_parallel.deepseek_v3",
     ]
 )
```
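These entries register the new experiments by dotted module name, which is what `--model.name auto_parallel.llama3` / `auto_parallel.deepseek_v3` refer to. As a rough, hypothetical illustration of how such name-based registration resolves to an experiment module (not torchtitan's actual registry code):

```python
# Hypothetical sketch only; torchtitan's real registration API may differ.
# Importing the experiment module is what makes its TrainSpec available.
import importlib

_SUPPORTED_EXPERIMENTS = frozenset(
    [
        "auto_parallel.llama3",
        "auto_parallel.deepseek_v3",
    ]
)


def load_experiment(name: str):
    if name not in _SUPPORTED_EXPERIMENTS:
        raise ValueError(f"unknown experiment: {name}")
    # e.g. "auto_parallel.llama3" -> torchtitan.experiments.auto_parallel.llama3
    return importlib.import_module(f"torchtitan.experiments.{name}")
```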
19 changes: 19 additions & 0 deletions torchtitan/experiments/auto_parallel/README.md
## Auto Parallel

### Overview

The Auto Parallel experiment integrates PyTorch's AutoParallel framework with TorchTitan to automatically optimize distributed training parallelism strategies given a device mesh. Instead of manually configuring parallelism layouts, AutoParallel uses cost-based analysis to determine optimal sharding placements for model parameters, activations, and gradients.
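The mesh AutoParallel optimizes over is a standard PyTorch `DeviceMesh`. As a rough sketch (illustrative values only; torchtitan builds its own world mesh from the parallelism config), a 2-D data-parallel / tensor-parallel mesh for an 8-GPU node could look like:

```python
# Illustrative only: the kind of 2-D DeviceMesh AutoParallel searches over.
# Run under torchrun with 8 ranks; sizes and dim names are example values.
from torch.distributed.device_mesh import init_device_mesh

dp_degree, tp_degree = 2, 4  # 2 * 4 = 8 ranks total
mesh = init_device_mesh("cuda", (dp_degree, tp_degree), mesh_dim_names=("dp", "tp"))
print(mesh.shape, mesh.mesh_dim_names)  # (2, 4) ('dp', 'tp')
```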

### Requirements

Requires installing [meta-pytorch/autoparallel](https://github.com/meta-pytorch/autoparallel) (clone from `git@github.com:meta-pytorch/autoparallel.git`).

### Single Node

**Llama3**

`CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.llama3 --parallelism.tensor_parallel_degree 4 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config`

**DeepSeekv3**

`CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config`
50 changes: 50 additions & 0 deletions torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import copy

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.distributed.pipeline_parallel import pipeline_llm
from torchtitan.hf_datasets.text_datasets import build_text_dataloader

from torchtitan.models.deepseek_v3 import deepseekv3_args, DeepSeekV3Model
from torchtitan.models.deepseek_v3.model.args import DeepSeekV3ModelArgs
from torchtitan.models.deepseek_v3.model.state_dict_adapter import (
    DeepSeekV3StateDictAdapter,
)
from torchtitan.protocols.train_spec import TrainSpec

from .parallelize_deepseekv3 import parallelize_deepseekv3


def get_train_spec() -> TrainSpec:
    model_args = copy.deepcopy(deepseekv3_args)

    # Reset attention settings to the DeepSeekV3ModelArgs defaults for every
    # config except the explicit flex_attn variants.
    default_args = DeepSeekV3ModelArgs()
    for config, args in model_args.items():
        if "flex_attn" in config:
            continue

        args.attn_type = default_args.attn_type
        args.attn_mask_type = default_args.attn_mask_type

    return TrainSpec(
        model_cls=DeepSeekV3Model,
        model_args=model_args,
        parallelize_fn=parallelize_deepseekv3,
        pipelining_fn=pipeline_llm,
        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_text_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        state_dict_adapter=DeepSeekV3StateDictAdapter,
    )
```