From 85c607a3326789269487f6f843e9f5fb36b41b0f Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 13 Aug 2024 00:13:48 -0700 Subject: [PATCH 1/7] Update [ghstack-poisoned] --- torchtitan/config_manager.py | 8 +++++- torchtitan/parallelisms/parallel_dims.py | 29 ++++++++++++++++---- torchtitan/parallelisms/parallelize_llama.py | 10 +++++-- train.py | 17 +++++++++++- 4 files changed, 53 insertions(+), 11 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 2bc37bfbf0..b1ac76bff4 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -316,7 +316,13 @@ def __init__(self): "--training.data_parallel_type", type=str, default="fsdp", - help="Data parallelism type. TorchTitan currently supports FSDP and DDP.", + help="Data parallelism type. TorchTitan currently supports FSDP, HSDP, and DDP.", + ) + self.parser.add_argument( + "--training.data_parallel_replicate_degree", + type=int, + default=1, + help="When data_parallel_type is HSDP, data parallelism has 2 different shardings: replicate and shard. This argument specifies the degree of replicate.", ) self.parser.add_argument( "--experimental.enable_compiled_autograd", diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index 22c114edae..d5b23e4248 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -18,23 +18,26 @@ class ParallelDims: pp: int world_size: int enable_loss_parallel: bool - dp_type: str + dp_type: str = "fsdp" + dp_replicate: int = 1 # Only used when dp_type is hsdp def __post_init__(self): self.dp_type = self.dp_type.lower() self._validate() def _validate(self): - dp, tp, pp = self.dp, self.tp, self.pp + dp, dp_replicate, tp, pp = self.dp, self.dp_replicate, self.tp, self.pp if dp == -1: self.dp = dp = self.world_size // (tp * pp) assert dp >= 1, dp + assert dp_replicate >= 1 and dp % dp_replicate == 0, (dp, dp_replicate) assert tp >= 1, tp assert pp >= 1, pp assert ( - dp * tp * pp == self.world_size + dp * tp * pp == self.world_size, ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" - assert self.dp_type in ("fsdp", "ddp") + assert self.dp_type in ("fsdp", "ddp", "hsdp") + assert self.dp_type != "hsdp" or dp_replicate > 1, (self.dp_type, dp_replicate) def build_mesh(self, device_type): dims = [] @@ -42,12 +45,26 @@ def build_mesh(self, device_type): for d, name in zip( [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True ): - if d > 1: + if d <= 1: + continue + + if name != "dp" or self.dp_replicate <= 1: dims.append(d) names.append(name) + continue + + dp_shard = self.dp // self.dp_replicate + dims.extend([self.dp_replicate, dp_shard]) + names.extend(["dp_replicate", "dp_shard"]) + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") names = tuple(names) - return init_device_mesh(device_type, dims, mesh_dim_names=names) + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + # Create all the submesh here to ensure all required process groups are + # initialized + if self.dp_replicate > 1: + mesh["dp_replicate", "dp_shard"]._flatten() + return mesh @property def dp_enabled(self): diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 18a0f452e4..3f951c1205 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -72,9 +72,13 @@ def parallelize_llama( apply_compile(model) if 
parallel_dims.dp_enabled: - if parallel_dims.dp_type == "fsdp": - dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh - assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names + if parallel_dims.dp_type in ("fsdp", "hsdp"): + if parallel_dims.dp_type == "hsdp": + dp_mesh = world_mesh["dp_replicate", "dp_shard"] + elif world_mesh.ndim > 1: + dp_mesh = world_mesh["dp"] + else: + dp_mesh = world_mesh apply_fsdp( model, diff --git a/train.py b/train.py index d297b8a7a8..0cfb84209f 100644 --- a/train.py +++ b/train.py @@ -29,6 +29,11 @@ ) from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling +try: + from torch.distributed.utils import _sync_module_states_with_mesh +except ImportError: + pass + def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool): @contextlib.contextmanager @@ -66,6 +71,7 @@ def main(job_config: JobConfig): world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, dp_type=job_config.training.data_parallel_type, + dp_replicate=job_config.training.data_parallel_replicate_degree, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") torch.cuda.set_device(device) @@ -77,7 +83,13 @@ def main(job_config: JobConfig): # build meshes world_mesh = parallel_dims.build_mesh(device_type="cuda") if parallel_dims.dp_enabled: - dp_mesh = world_mesh["dp"] + if parallel_dims.dp_type == "hsdp": + # Both dp_replicate and dp_shard belong to data parallelism and + # we need to flatten them to get the true dp_mesh for the dataloader + # and loss gathering. + dp_mesh = world_mesh["dp_replicate", "dp_shard"]._flatten() + else: + dp_mesh = world_mesh["dp"] dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 @@ -210,6 +222,9 @@ def loss_fn(pred, labels): "All the substages will be initialized with random weights with same RNG state which can affect convergence." ) + if not checkpoint_loaded and parallel_dims.dp_type == "hsdp": + _sync_module_states_with_mesh(model, world_mesh["dp_replicate"]) + metric_logger = build_metric_logger(job_config, parallel_dims) # plot losses loaded from checkpoint (if any) to TensorBoard From 69c964b371d6d748611ababc3d1aece4913ed736 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 15 Aug 2024 11:45:18 -0700 Subject: [PATCH 2/7] Update [ghstack-poisoned] --- torchtitan/config_manager.py | 10 +++++- torchtitan/parallelisms/parallel_dims.py | 41 +++++++++++++----------- train.py | 4 --- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index b1ac76bff4..48d538afd1 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -226,8 +226,14 @@ def __init__(self): self.parser.add_argument( "--training.data_parallel_degree", type=int, + nargs="+", default=-1, - help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.", + help=""" + Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). + 1 means disabled. If HSDP is used, there should be 2 integers. The first + one means the replicate degree and the second one mean the shard degree. + -1 is not supported in HSDP case. + """, ) self.parser.add_argument( "--training.tensor_parallel_degree", @@ -608,6 +614,8 @@ def parse_args_from_command_line( # since the inferred type is just 'list' and it ends up flattening # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...] 
aux_parser.add_argument("--" + arg, type=string_list) + elif isinstance(val, list): + aux_parser.add_argument("--" + arg, type=type(val[0]), nargs="+") else: aux_parser.add_argument("--" + arg, type=type(val)) diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index d5b23e4248..531dfe2654 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from functools import cached_property +from typing import List, Tuple, Union from torch.distributed.device_mesh import init_device_mesh from torchtitan.logging import logger @@ -13,28 +14,37 @@ @dataclass class ParallelDims: - dp: int + dp: Union[int, List[int]] tp: int pp: int world_size: int enable_loss_parallel: bool dp_type: str = "fsdp" - dp_replicate: int = 1 # Only used when dp_type is hsdp def __post_init__(self): self.dp_type = self.dp_type.lower() self._validate() + def _get_dp(self) -> Tuple[int, int]: + if isinstance(self.dp, (tuple, list)): + return self.dp[0], self.dp[1] + elif self.dp_type == "fsdp": + return 1, self.dp + else: + return self.dp, 1 + def _validate(self): - dp, dp_replicate, tp, pp = self.dp, self.dp_replicate, self.tp, self.pp + dp, tp, pp = self.dp, self.tp, self.pp if dp == -1: self.dp = dp = self.world_size // (tp * pp) - assert dp >= 1, dp - assert dp_replicate >= 1 and dp % dp_replicate == 0, (dp, dp_replicate) + + dp_replicate, dp_shard = self._get_dp() + assert dp_replicate >= 1, self.dp + assert dp_shard >= 1, self.dp assert tp >= 1, tp assert pp >= 1, pp assert ( - dp * tp * pp == self.world_size, + dp_replicate * dp_shard * tp * pp == self.world_size, ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" assert self.dp_type in ("fsdp", "ddp", "hsdp") assert self.dp_type != "hsdp" or dp_replicate > 1, (self.dp_type, dp_replicate) @@ -42,33 +52,28 @@ def _validate(self): def build_mesh(self, device_type): dims = [] names = [] + dp_replicate, dp_shard = self._get_dp() for d, name in zip( - [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True + [self.pp, dp_replicate, dp_shard, self.tp], + ["pp", "dp_replicate", "dp_shard", "tp"], + strict=True, ): - if d <= 1: - continue - - if name != "dp" or self.dp_replicate <= 1: + if d > 1: dims.append(d) names.append(name) - continue - - dp_shard = self.dp // self.dp_replicate - dims.extend([self.dp_replicate, dp_shard]) - names.extend(["dp_replicate", "dp_shard"]) logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") names = tuple(names) mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) # Create all the submesh here to ensure all required process groups are # initialized - if self.dp_replicate > 1: + if dp_replicate > 1: mesh["dp_replicate", "dp_shard"]._flatten() return mesh @property def dp_enabled(self): - return self.dp > 1 + return isinstance(self.dp, list) or self.dp > 1 @property def tp_enabled(self): diff --git a/train.py b/train.py index 0cfb84209f..0e520f37f2 100644 --- a/train.py +++ b/train.py @@ -71,7 +71,6 @@ def main(job_config: JobConfig): world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, dp_type=job_config.training.data_parallel_type, - dp_replicate=job_config.training.data_parallel_replicate_degree, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") torch.cuda.set_device(device) @@ -222,9 +221,6 @@ def loss_fn(pred, labels): "All the substages will be initialized with random weights with 
same RNG state which can affect convergence." ) - if not checkpoint_loaded and parallel_dims.dp_type == "hsdp": - _sync_module_states_with_mesh(model, world_mesh["dp_replicate"]) - metric_logger = build_metric_logger(job_config, parallel_dims) # plot losses loaded from checkpoint (if any) to TensorBoard From 15f0454e7babaa9e7d5b951ee4bce2166f9fb33a Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 16 Aug 2024 16:17:01 -0700 Subject: [PATCH 3/7] Update [ghstack-poisoned] --- torchtitan/config_manager.py | 8 +------- torchtitan/parallelisms/parallel_dims.py | 2 +- train.py | 5 ----- 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 48d538afd1..9af3633ccb 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -231,7 +231,7 @@ def __init__(self): help=""" Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled. If HSDP is used, there should be 2 integers. The first - one means the replicate degree and the second one mean the shard degree. + one means the replicate degree and the second one means the shard degree. -1 is not supported in HSDP case. """, ) @@ -324,12 +324,6 @@ def __init__(self): default="fsdp", help="Data parallelism type. TorchTitan currently supports FSDP, HSDP, and DDP.", ) - self.parser.add_argument( - "--training.data_parallel_replicate_degree", - type=int, - default=1, - help="When data_parallel_type is HSDP, data parallelism has 2 different shardings: replicate and shard. This argument specifies the degree of replicate.", - ) self.parser.add_argument( "--experimental.enable_compiled_autograd", action="store_true", diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index 531dfe2654..36880a3cb1 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -46,7 +46,7 @@ def _validate(self): assert ( dp_replicate * dp_shard * tp * pp == self.world_size, ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" - assert self.dp_type in ("fsdp", "ddp", "hsdp") + assert self.dp_type in ("fsdp", "ddp", "hsdp"), self.dp_type assert self.dp_type != "hsdp" or dp_replicate > 1, (self.dp_type, dp_replicate) def build_mesh(self, device_type): diff --git a/train.py b/train.py index 0e520f37f2..e1e7e1869b 100644 --- a/train.py +++ b/train.py @@ -29,11 +29,6 @@ ) from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling -try: - from torch.distributed.utils import _sync_module_states_with_mesh -except ImportError: - pass - def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool): @contextlib.contextmanager From 5aac73f5209a5b3cfe6b83a5c02701b0b6898c02 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 3 Sep 2024 13:51:49 -0700 Subject: [PATCH 4/7] Update [ghstack-poisoned] --- torchtitan/parallelisms/parallel_dims.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index 6d6ec3d2dd..df25214602 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -6,7 +6,6 @@ from dataclasses import dataclass from functools import cached_property -from typing import List, Tuple, Union from torch.distributed.device_mesh import init_device_mesh from torchtitan.logging import logger From 3ffb822f87df897a175c07ae55f56fd5d9a759b3 Mon Sep 17 00:00:00 
2001 From: Chien-Chin Huang Date: Tue, 3 Sep 2024 15:24:42 -0700 Subject: [PATCH 5/7] Update [ghstack-poisoned] --- estimation.py | 4 ++-- torchtitan/config_manager.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/estimation.py b/estimation.py index 13ccd4c16b..f58907c6f7 100644 --- a/estimation.py +++ b/estimation.py @@ -64,12 +64,12 @@ def estimate_memory(job_config: JobConfig): job_config.experimental.enable_compiled_autograd = False parallel_dims = ParallelDims( - dp=job_config.training.data_parallel_degree, + dp_shard=job_config.training.data_parallel_shard_degree, + dp_replicate=job_config.training.data_parallel_replicate_degree, tp=job_config.training.tensor_parallel_degree, pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, - dp_type=job_config.training.data_parallel_type, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 74d1346860..86a056fa21 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -252,7 +252,7 @@ def __init__(self): method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the parallelism method used is FSDP (Fully Sharded Data Parallelism). - -1 means leftover ranks will be used (After DP_REPLICATED/SP/PP). Note that + -1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that only one of `data_parallel_replicate_degree` and `data_parallel_shard_degree` can be negative. 1 means disabled.""", From 45d744ad8523d8872ee8c193c4242ed8ca4c77d0 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 9 Sep 2024 23:14:57 -0700 Subject: [PATCH 6/7] Update [ghstack-poisoned] --- torchtitan/config_manager.py | 4 ---- torchtitan/parallelisms/parallel_dims.py | 14 ++++---------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 4ee2806ff2..67c82d53f8 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -234,10 +234,6 @@ def __init__(self): ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the parallelism method used is DDP (Distributed Data Parallelism). - - -1 means leftover ranks will be used (After DP_SHARD/SP/PP). Note that only - one of `data_parallel_replicate_degree` and `data_parallel_shard_degree` can - be negative. 1 means disabled.""", ) self.parser.add_argument( diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index 2724c05def..3c13d80ad6 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -30,20 +30,14 @@ def _validate(self): self.tp, self.pp, ) - assert ( - dp_replicate >= -1 and dp_shard >= -1 and dp_replicate * dp_shard != 0 - ), "dp_replicate and dp_shard must -1 or >=1." - assert ( - dp_replicate != -1 or dp_shard != -1 - ), "Only one of dp_replicate, dp_shard can be -1" + for d in (dp_replicate, tp, pp): + assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" + assert dp_shard == -1 or dp_replicate >= 1, " dp_shard must -1 or >=1." 
dp = dp_replicate * dp_shard if dp < 0: dp = self.world_size // (tp * pp) - if dp_replicate == -1: - self.dp_replicate = dp_replicate = dp // dp_shard - if dp_shard == -1: - self.dp_shard = dp_shard = dp // dp_replicate + self.dp_shard = dp_shard = dp // dp_replicate assert dp_replicate >= 1 assert dp_shard >= 1 From c77b0c20c68627f0256647b8e7ef18a9a04c3eeb Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 10 Sep 2024 09:00:56 -0700 Subject: [PATCH 7/7] Update [ghstack-poisoned] --- test_runner.py | 4 ++-- torchtitan/parallelisms/parallel_dims.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test_runner.py b/test_runner.py index 19dda7e70d..6d706a641b 100755 --- a/test_runner.py +++ b/test_runner.py @@ -286,7 +286,7 @@ def build_test_list(): [ [ "--training.data_parallel_shard_degree=1", - "--training.data_parallel_replicate_degree=-1", + "--training.data_parallel_replicate_degree=4", ] ], "DDP", @@ -297,7 +297,7 @@ def build_test_list(): [ [ "--training.data_parallel_shard_degree=2", - "--training.data_parallel_replicate_degree=-1", + "--training.data_parallel_replicate_degree=2", ] ], "HSDP", diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index 3c13d80ad6..2e2aacc75b 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -32,7 +32,7 @@ def _validate(self): ) for d in (dp_replicate, tp, pp): assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" - assert dp_shard == -1 or dp_replicate >= 1, " dp_shard must -1 or >=1." + assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." dp = dp_replicate * dp_shard if dp < 0:
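
A simplified, standalone sketch (not part of the patches above) of the degree-resolution and mesh-dim layout the series converges on in patches 6/7: `data_parallel_replicate_degree` and `data_parallel_shard_degree` replace the earlier `data_parallel_degree`/`data_parallel_type` pair, a shard degree of -1 absorbs the leftover ranks after replicate/TP/PP, and only degrees greater than 1 become device-mesh dimensions. The function name `resolve_parallel_dims` and the demo values are illustrative only, not torchtitan API.

    # Sketch mirroring ParallelDims._validate()/build_mesh() after patches 6-7.
    from typing import List, Tuple


    def resolve_parallel_dims(
        world_size: int,
        dp_replicate: int = 1,
        dp_shard: int = -1,
        tp: int = 1,
        pp: int = 1,
    ) -> Tuple[List[int], List[str]]:
        """Return (dims, names) for the device mesh; only degrees > 1 are kept."""
        for d in (dp_replicate, tp, pp):
            assert d >= 1, "parallelism degrees other than dp_shard must be >= 1"
        assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1"

        if dp_shard == -1:  # use the leftover ranks after dp_replicate/TP/PP
            dp_shard = world_size // (dp_replicate * tp * pp)
        assert dp_replicate * dp_shard * tp * pp == world_size, (
            f"dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * tp({tp}) "
            f"* pp({pp}) != WORLD_SIZE({world_size})"
        )

        dims, names = [], []
        for d, name in zip(
            [pp, dp_replicate, dp_shard, tp],
            ["pp", "dp_replicate", "dp_shard", "tp"],
        ):
            if d > 1:  # degree-1 dims are dropped from the mesh
                dims.append(d)
                names.append(name)
        return dims, names


    if __name__ == "__main__":
        # 16 ranks, HSDP with 2-way replication, shard degree inferred, TP=2.
        print(resolve_parallel_dims(world_size=16, dp_replicate=2, tp=2))
        # ([2, 4, 2], ['dp_replicate', 'dp_shard', 'tp'])

In the actual patches, `ParallelDims.build_mesh` passes these dims and names to `init_device_mesh`, and when both `dp_replicate` and `dp_shard` are present, `train.py` flattens `world_mesh["dp_replicate", "dp_shard"]` into a single data-parallel mesh used for the dataloader and loss gathering.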