[RFC] Enable HSDP #518
Changes from all commits
```diff
@@ -13,45 +13,78 @@
 @dataclass
 class ParallelDims:
-    dp: int
+    dp_replicate: int
+    dp_shard: int
     tp: int
     pp: int
     world_size: int
     enable_loss_parallel: bool
-    dp_type: str

     def __post_init__(self):
-        self.dp_type = self.dp_type.lower()
         self._validate()

     def _validate(self):
-        dp, tp, pp = self.dp, self.tp, self.pp
-        if dp == -1:
-            self.dp = dp = self.world_size // (tp * pp)
-        assert dp >= 1, dp
+        dp_replicate, dp_shard, tp, pp = (
+            self.dp_replicate,
+            self.dp_shard,
+            self.tp,
+            self.pp,
+        )
+        for d in (dp_replicate, tp, pp):
+            assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
+        assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1."
+
+        dp = dp_replicate * dp_shard
+        if dp < 0:
+            dp = self.world_size // (tp * pp)
+            self.dp_shard = dp_shard = dp // dp_replicate
+
+        assert dp_replicate >= 1
+        assert dp_shard >= 1
         assert tp >= 1, tp
         assert pp >= 1, pp
-        assert (
-            dp * tp * pp == self.world_size
-        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
-        assert self.dp_type in ("fsdp", "ddp")
+        assert dp_replicate * dp_shard * tp * pp == self.world_size, (
+            f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
+            f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+        )

     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
+            [self.pp, self.dp_replicate, self.dp_shard, self.tp],
+            ["pp", "dp_replicate", "dp_shard", "tp"],
+            strict=True,
         ):
             if d > 1:
                 dims.append(d)
-                names.append(name)
+                if (name == "dp_replicate" and self.dp_shard == 1) or (
+                    name == "dp_shard" and self.dp_replicate == 1
+                ):
+                    names.append("dp")
+                else:
+                    names.append(name)

         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         names = tuple(names)
-        return init_device_mesh(device_type, dims, mesh_dim_names=names)
+        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+        # Create all the submeshes here to ensure all required process groups
+        # are initialized
+        if self.dp_replicate > 1 and self.dp_shard > 1:
+            mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
```
Comment on lines +73 to +74

Contributor: May I ask when and why we need the flattened "dp" mesh? Is it just for HSDP?

Contributor (Author): No, the "dp" mesh is needed for the dataloader and the loss computation. It's easier for those two components to only know about "dp", so I make sure a "dp" mesh always exists.

Contributor: Oh, that makes sense!
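For illustration only (this snippet is not part of the PR), the sketch below shows how a dp-only consumer such as the dataloader might use the flattened "dp" mesh. The helper name `dp_info`, the `world_mesh` argument, and the commented-out `DistributedSampler`/`dataset` usage are assumptions; only the `"dp"` lookup reflects the code above.

```python
from torch.distributed.device_mesh import DeviceMesh


def dp_info(world_mesh: DeviceMesh) -> tuple[int, int]:
    """Return (dp_rank, dp_degree) from the "dp" mesh dimension.

    Hypothetical helper, not from this PR; assumes the mesh was built by
    ParallelDims.build_mesh() with data parallelism enabled, so a "dp"
    dimension exists for DDP, FSDP, and HSDP alike.
    """
    dp_mesh = world_mesh["dp"]
    return dp_mesh.get_local_rank(), dp_mesh.size()


# e.g. give each dp rank a distinct shard of the data (dataset assumed):
# dp_rank, dp_degree = dp_info(world_mesh)
# sampler = DistributedSampler(dataset, num_replicas=dp_degree, rank=dp_rank)
```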
```diff
+
+        return mesh

     @property
     def dp_enabled(self):
-        return self.dp > 1
+        return self.dp_replicate > 1 or self.dp_shard > 1
+
+    @property
+    def dp_replicate_enabled(self):
+        return self.dp_replicate > 1
+
+    @property
+    def dp_shard_enabled(self):
+        return self.dp_shard > 1

     @property
     def tp_enabled(self):
```
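As a usage sketch (not taken from the PR), the new fields could be exercised as shown below; the world size, device type, and the assumption of a 16-rank `torchrun` launch with the default process group already initialized are made up for illustration.

```python
# Illustrative only: 16 ranks arranged as HSDP = 2-way replication x 8-way sharding.
parallel_dims = ParallelDims(
    dp_replicate=2,
    dp_shard=-1,              # _validate() resolves this to 16 // (1 * 1) // 2 = 8
    tp=1,
    pp=1,
    world_size=16,
    enable_loss_parallel=False,
)
world_mesh = parallel_dims.build_mesh(device_type="cuda")
# Mesh dims: ("dp_replicate", "dp_shard") with shape (2, 8); build_mesh() also
# flattens them into a 16-rank "dp" submesh for dp-only consumers.
```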
Maybe don't change this, but it is not obvious that we need to add "dp". What is the downside of keeping the original names? "dp_replicate" is clearer than "dp" if someone is looking at process group names and wondering which parallelism is used.

We need to add "dp" because the loss computation and the dataloader require it, whether "dp" is dp_replicate + dp_shard (HSDP) or just dp_shard (FSDP). These two components care only about "dp".

That makes sense: the device mesh will have the axis "dp" when DDP or FSDP is used, and "dp_replicate" and "dp_shard" (plus their flattened "dp") when HSDP is used. One corner case is self.world_size == tp * pp, where two "dp" entries would be added to names.

Hmm, I don't think so: if both dp_replicate and dp_shard are 1, the `if d > 1` check on line 59 is never true, so we never add a "dp" mesh dimension at all.
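To make that corner case concrete, here is a small standalone sketch; `mesh_dim_names` is a hypothetical helper that merely mirrors the naming logic from `build_mesh()` above and is not part of the PR.

```python
def mesh_dim_names(pp: int, dp_replicate: int, dp_shard: int, tp: int) -> list[str]:
    """Mirror build_mesh() naming: skip degree-1 dims and rename the single
    active data-parallel dimension to "dp"."""
    names = []
    for d, name in zip(
        [pp, dp_replicate, dp_shard, tp], ["pp", "dp_replicate", "dp_shard", "tp"]
    ):
        if d > 1:
            if (name == "dp_replicate" and dp_shard == 1) or (
                name == "dp_shard" and dp_replicate == 1
            ):
                names.append("dp")
            else:
                names.append(name)
    return names


print(mesh_dim_names(pp=1, dp_replicate=1, dp_shard=8, tp=1))  # FSDP: ['dp']
print(mesh_dim_names(pp=1, dp_replicate=2, dp_shard=4, tp=1))  # HSDP: ['dp_replicate', 'dp_shard']
print(mesh_dim_names(pp=2, dp_replicate=1, dp_shard=1, tp=4))  # tp * pp == world_size: ['pp', 'tp'], no "dp"
```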