4 changes: 2 additions & 2 deletions estimation.py
@@ -64,12 +64,12 @@ def estimate_memory(job_config: JobConfig):
job_config.experimental.enable_compiled_autograd = False

parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
dp_shard=job_config.training.data_parallel_shard_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
dp_type=job_config.training.data_parallel_type,
)

device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
40 changes: 32 additions & 8 deletions test_runner.py
@@ -157,7 +157,7 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 1",
"--training.data_parallel_shard_degree 1",
],
],
"PP 1D test 1f1b",
@@ -172,7 +172,7 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 1",
"--training.data_parallel_shard_degree 1",
],
],
"PP 1D test gpipe",
@@ -187,7 +187,7 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
],
],
"PP+DP 1f1b 2D test",
@@ -201,7 +201,7 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
],
],
"PP+DP gpipe 2D test",
@@ -227,15 +227,15 @@ def build_test_list():
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
],
[
"--training.steps 20",
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
],
],
@@ -249,7 +249,7 @@ def build_test_list():
[
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
"--training.compile",
],
@@ -285,13 +285,37 @@ def build_test_list():
OverrideDefinitions(
[
[
"--training.data_parallel_type ddp",
"--training.data_parallel_shard_degree=1",
"--training.data_parallel_replicate_degree=4",
]
],
"DDP",
"ddp",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--training.data_parallel_replicate_degree=2",
]
],
"HSDP",
"hsdp",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--training.data_parallel_replicate_degree=2",
"--training.tensor_parallel_degree=2",
]
],
"HSDP+TP",
"hsdp+tp",
ngpu=8,
),
OverrideDefinitions(
[
[
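
The three tests added above pin every parallelism degree explicitly, so their product must match the ngpu each test requests. A quick arithmetic check (an illustrative Python sketch; the dictionary below is ours, not part of test_runner.py):

# Sanity check (illustrative, not part of test_runner.py): the product of the
# parallelism degrees must equal the number of GPUs each new test requests.
new_tests = {
    "ddp":     {"dp_replicate": 4, "dp_shard": 1, "tp": 1, "ngpu": 4},
    "hsdp":    {"dp_replicate": 2, "dp_shard": 2, "tp": 1, "ngpu": 4},
    "hsdp+tp": {"dp_replicate": 2, "dp_shard": 2, "tp": 2, "ngpu": 8},
}
for name, cfg in new_tests.items():
    product = cfg["dp_replicate"] * cfg["dp_shard"] * cfg["tp"]
    assert product == cfg["ngpu"], f"{name}: {product} != {cfg['ngpu']}"
    print(f"{name}: dp_replicate * dp_shard * tp = {product} GPUs")
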
34 changes: 26 additions & 8 deletions torchtitan/config_manager.py
@@ -224,10 +224,34 @@ def __init__(self):
help="How many train steps to run",
)
self.parser.add_argument(
"--training.data_parallel_degree",
"--training.data_parallel_replicate_degree",
type=int,
default=1,
help="""
The `data_parallel_replicate_degree` argument specifies the degree of
data parallelism for weight replication. When this value is greater
than 1, weights will be replicated across `data_parallel_replicate_degree`
ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism
method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
parallelism method used is DDP (Distributed Data Parallelism).
1 means disabled.""",
)
self.parser.add_argument(
"--training.data_parallel_shard_degree",
type=int,
default=-1,
help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
help="""
The `data_parallel_shard_degree` argument specifies the degree of data
parallelism for weight sharding. When this value is greater than 1, weights
will be sharded across `data_parallel_shard_degree` ranks. If
`data_parallel_replicate_degree` is also greater than 1, the parallelism
method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
parallelism method used is FSDP (Fully Sharded Data Parallelism).

-1 means leftover ranks will be used (after DP_REPLICATE/SP/PP). Note that
only one of `data_parallel_replicate_degree` and `data_parallel_shard_degree`
can be negative.
1 means disabled.""",
)
self.parser.add_argument(
"--training.tensor_parallel_degree",
@@ -297,12 +321,6 @@ def __init__(self):
The default value will be the number of pipeline stages, if unspecified.
""",
)
self.parser.add_argument(
"--training.data_parallel_type",
type=str,
default="fsdp",
help="Data parallelism type. TorchTitan currently supports FSDP and DDP.",
)
self.parser.add_argument(
"--experimental.enable_compiled_autograd",
action="store_true",
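
To make the new help text concrete, here is a small standalone sketch (ours, not code from this PR) of how the -1 default for data_parallel_shard_degree resolves and how the two degrees together select DDP, FSDP, or HSDP:

# Illustrative sketch (not part of config_manager.py): resolve dp_shard == -1
# and name the resulting data-parallel strategy.
def resolve_dp(world_size, dp_replicate, dp_shard, tp, pp):
    if dp_shard == -1:
        # -1 means "use whatever ranks remain after dp_replicate, TP, and PP".
        dp_shard = world_size // (dp_replicate * tp * pp)
    assert dp_replicate * dp_shard * tp * pp == world_size
    if dp_replicate > 1 and dp_shard > 1:
        strategy = "HSDP"
    elif dp_shard > 1:
        strategy = "FSDP"
    elif dp_replicate > 1:
        strategy = "DDP"
    else:
        strategy = "no data parallelism"
    return dp_replicate, dp_shard, strategy

# The shipped defaults (replicate=1, shard=-1) on 8 GPUs with TP=2 give 4-way FSDP.
print(resolve_dp(world_size=8, dp_replicate=1, dp_shard=-1, tp=2, pp=1))  # (1, 4, 'FSDP')
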
63 changes: 48 additions & 15 deletions torchtitan/parallelisms/parallel_dims.py
@@ -13,45 +13,78 @@

@dataclass
class ParallelDims:
dp: int
dp_replicate: int
dp_shard: int
tp: int
pp: int
world_size: int
enable_loss_parallel: bool
dp_type: str

def __post_init__(self):
self.dp_type = self.dp_type.lower()
self._validate()

def _validate(self):
dp, tp, pp = self.dp, self.tp, self.pp
if dp == -1:
self.dp = dp = self.world_size // (tp * pp)
assert dp >= 1, dp
dp_replicate, dp_shard, tp, pp = (
self.dp_replicate,
self.dp_shard,
self.tp,
self.pp,
)
for d in (dp_replicate, tp, pp):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1."

dp = dp_replicate * dp_shard
if dp < 0:
dp = self.world_size // (tp * pp)
self.dp_shard = dp_shard = dp // dp_replicate

assert dp_replicate >= 1
assert dp_shard >= 1
assert tp >= 1, tp
assert pp >= 1, pp
assert (
dp * tp * pp == self.world_size
), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
assert self.dp_type in ("fsdp", "ddp")
assert dp_replicate * dp_shard * tp * pp == self.world_size, (
f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
)

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
[self.pp, self.dp_replicate, self.dp_shard, self.tp],
["pp", "dp_replicate", "dp_shard", "tp"],
strict=True,
):
if d > 1:
dims.append(d)
names.append(name)
if (name == "dp_replicate" and self.dp_shard == 1) or (
name == "dp_shard" and self.dp_replicate == 1
):
names.append("dp")
[Review thread on `names.append("dp")`]
wconstab (Contributor), Sep 5, 2024: Maybe don't change, but it is not obvious that we need to add 'dp'. What is the downside of leaving the original names? 'dp_replicate' is clearer than 'dp' if someone is looking at PG names and wondering what parallelism is used.
fegin (Contributor, Author): We need to add dp because dp is required for loss computation and the dataloader, whether dp is dp_replicate + dp_shard (HSDP) or dp_shard alone (FSDP). These two components only care about dp.
Contributor: This makes sense: the device mesh will have the axis (1) dp when DDP or FSDP is used, or (2) dp_shard and dp_replicate, plus their flattened dp, when HSDP is used. One corner case is self.world_size == tp * pp, where two dp axes would be added to names.
fegin (Contributor, Author): Hmm, I don't think so. If both dp_replicate and dp_shard are 1, the `if d > 1` check (line 59) won't be true, so we will never add the dp mesh.
[End review thread]
else:
names.append(name)

logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
names = tuple(names)
return init_device_mesh(device_type, dims, mesh_dim_names=names)
mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
# Create all the submesh here to ensure all required process groups are
# initialized
if self.dp_replicate > 1 and self.dp_shard > 1:
mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
[Review thread on lines +73 to +74]
Contributor: May I ask when and why we need the flattened "dp" mesh? Is it just for HSDP?
fegin (Contributor, Author), Sep 10, 2024: No, DP is needed for the dataloader and loss computation. It's easier for the dataloader and loss computation to only know about DP, so I ensure there always exists a DP mesh.
Contributor: Oh, makes sense!
[End review thread]
return mesh

@property
def dp_enabled(self):
return self.dp > 1
return self.dp_replicate > 1 or self.dp_shard > 1

@property
def dp_replicate_enabled(self):
return self.dp_replicate > 1

@property
def dp_shard_enabled(self):
return self.dp_shard > 1

@property
def tp_enabled(self):
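
The review thread above is about which mesh dimension names build_mesh produces. A simplified sketch of just the naming rule (ours; it mirrors the logic in the diff but does not construct a real DeviceMesh):

# Illustrative sketch (not the real build_mesh): which mesh dims and names
# result from a given combination of degrees.
def mesh_names(pp, dp_replicate, dp_shard, tp):
    dims, names = [], []
    for d, name in zip(
        [pp, dp_replicate, dp_shard, tp],
        ["pp", "dp_replicate", "dp_shard", "tp"],
    ):
        if d > 1:
            dims.append(d)
            # When only one of the two dp dims is active, call it "dp" directly;
            # when both are active (HSDP), keep both and flatten into "dp" later.
            if (name == "dp_replicate" and dp_shard == 1) or (
                name == "dp_shard" and dp_replicate == 1
            ):
                names.append("dp")
            else:
                names.append(name)
    return dims, names

print(mesh_names(pp=1, dp_replicate=4, dp_shard=1, tp=1))  # ([4], ['dp'])           -> DDP
print(mesh_names(pp=1, dp_replicate=1, dp_shard=4, tp=2))  # ([4, 2], ['dp', 'tp'])  -> FSDP + TP
print(mesh_names(pp=1, dp_replicate=2, dp_shard=2, tp=2))  # ([2, 2, 2], ['dp_replicate', 'dp_shard', 'tp']) -> HSDP + TP
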
14 changes: 9 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -73,9 +73,11 @@ def parallelize_llama(
apply_compile(model)

if parallel_dims.dp_enabled:
if parallel_dims.dp_type == "fsdp":
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
if parallel_dims.dp_shard_enabled:
if parallel_dims.dp_replicate_enabled:
dp_mesh = world_mesh["dp_replicate", "dp_shard"]
else:
dp_mesh = world_mesh["dp"]

apply_fsdp(
model,
Expand All @@ -87,6 +89,10 @@ def parallelize_llama(
tp_enabled=parallel_dims.tp_enabled,
pp_enabled=parallel_dims.pp_enabled,
)
if parallel_dims.dp_replicate_enabled:
logger.info("Applied HSDP to the model")
else:
logger.info("Applied FSDP to the model")
else:
if world_mesh.ndim > 1:
raise RuntimeError("DDP has not supported > 1D parallelism")
@@ -322,8 +328,6 @@ def apply_fsdp(
)
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)

logger.info("Applied FSDP to the model")


def apply_ddp(
model: nn.Module,
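
With the change above, parallelize_llama hands the sharding path a 2-D (replicate, shard) submesh for HSDP and the 1-D "dp" mesh for plain FSDP. A minimal sketch of that branch in isolation (our helper name, assuming a world mesh built by ParallelDims.build_mesh):

# Illustrative sketch (our helper, not part of parallelize_llama.py): choose the
# data-parallel mesh that FSDP2's fully_shard should operate over.
def pick_dp_mesh(world_mesh, parallel_dims):
    # Only reached when parallel_dims.dp_shard_enabled is True.
    if parallel_dims.dp_replicate_enabled:
        # HSDP: shard within "dp_shard", replicate across "dp_replicate".
        return world_mesh["dp_replicate", "dp_shard"]
    # Plain FSDP: shard across the single "dp" dimension.
    return world_mesh["dp"]

Passing the 2-D mesh is what gives fully_shard its HSDP behavior, which is why the log message now distinguishes "Applied HSDP" from "Applied FSDP".
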
4 changes: 2 additions & 2 deletions train.py
@@ -59,12 +59,12 @@ def main(job_config: JobConfig):
# init distributed
world_size = int(os.environ["WORLD_SIZE"])
parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
dp_shard=job_config.training.data_parallel_shard_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
dp_type=job_config.training.data_parallel_type,
)
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
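
Downstream of this constructor, the dataloader and the loss computation only need the flattened "dp" dimension, regardless of whether DDP, FSDP, or HSDP is active (see the review thread in parallel_dims.py above). A hedged sketch of that usage (assumes a 4-GPU torchrun launch, the module path shown in this diff, and a PyTorch version where a flattened mesh dim can be sliced from the root mesh; get_local_rank and size are standard DeviceMesh methods):

# Illustrative sketch: data-loading code can stay agnostic of DDP/FSDP/HSDP by
# only looking at the flattened "dp" mesh dimension.
# Assumes: torchrun --nproc_per_node=4 this_script.py
import os

from torchtitan.parallelisms.parallel_dims import ParallelDims

parallel_dims = ParallelDims(
    dp_shard=2,
    dp_replicate=2,
    tp=1,
    pp=1,
    world_size=int(os.environ["WORLD_SIZE"]),
    enable_loss_parallel=False,
)
world_mesh = parallel_dims.build_mesh(device_type="cuda")

if parallel_dims.dp_enabled:
    dp_mesh = world_mesh["dp"]          # present for DDP, FSDP, and HSDP alike
    dp_rank = dp_mesh.get_local_rank()  # which slice of the dataset this rank reads
    dp_degree = dp_mesh.size()          # how many ways the global batch is split
    print(f"reading data shard {dp_rank} of {dp_degree}")
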
3 changes: 2 additions & 1 deletion train_configs/debug_model.toml
@@ -35,7 +35,8 @@ seq_len = 2048
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_degree = -1
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
compile = false
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
3 changes: 2 additions & 1 deletion train_configs/llama2_13b.toml
@@ -31,7 +31,8 @@ seq_len = 4096
warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
compile = false
dataset = "c4"
3 changes: 2 additions & 1 deletion train_configs/llama2_70b.toml
@@ -31,7 +31,8 @@ seq_len = 4096
warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8 # 8-way TP
compile = false
dataset = "c4"
3 changes: 2 additions & 1 deletion train_configs/llama2_7b.toml
@@ -30,7 +30,8 @@ seq_len = 2048
warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1 # dp-only would be sufficient for 7B
compile = false
dataset = "c4"
3 changes: 2 additions & 1 deletion train_configs/llama3_405b.toml
@@ -31,7 +31,8 @@ seq_len = 8192
warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 3000
data_parallel_degree = -1
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8 # 8-way TP
compile = true
dataset = "c4"
3 changes: 2 additions & 1 deletion train_configs/llama3_70b.toml
@@ -31,7 +31,8 @@ seq_len = 8192
warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8 # 8-way TP
compile = false
dataset = "c4"
3 changes: 2 additions & 1 deletion train_configs/llama3_8b.toml
@@ -31,7 +31,8 @@ seq_len = 8192
warmup_steps = 200 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
compile = false
dataset = "c4"