Description
I see that the data parallel shard dimension is factored into two dimensions, dp_shard_mod_ep and dp_shard_in_ep.
The experts use the dp_shard_mod_ep submesh for FSDP, while the rest of the blocks use the regular dp_shard_cp submesh. Why can't the experts use FSDP on the regular dp_mesh? The reason is unclear to me after reading the code. If only expert parallelism is used without data parallelism, or if the data parallel size is smaller than the expert parallel size, then the dp_shard_mod_ep dimension would have size 0, which doesn't make sense.
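To make the size-0 concern concrete, here is a minimal sketch of the factoring as I read it, not the actual library code. It assumes cp = 1 and tp = 1, so that ep equals dp_shard_in_ep, and it assumes dp_shard_mod_ep is obtained by integer division of dp_shard by ep. The helper name factor_dp_shard is hypothetical.

```python
def factor_dp_shard(dp_shard: int, ep: int) -> tuple[int, int]:
    """Hypothetical helper: split dp_shard into (dp_shard_mod_ep, dp_shard_in_ep).

    Assumes cp = 1 and tp = 1, so the ep submesh is just dp_shard_in_ep.
    The integer division is an assumption about how the sizes are derived.
    """
    dp_shard_in_ep = ep
    dp_shard_mod_ep = dp_shard // ep  # rounds down to 0 when dp_shard < ep
    return dp_shard_mod_ep, dp_shard_in_ep


# Try a few (dp_shard, ep) combinations and check whether the two
# factors multiply back to the original dp_shard size.
for dp_shard, ep in [(8, 2), (8, 8), (2, 8)]:
    mod_ep, in_ep = factor_dp_shard(dp_shard, ep)
    consistent = mod_ep * in_ep == dp_shard
    print(f"dp_shard={dp_shard} ep={ep} -> "
          f"dp_shard_mod_ep={mod_ep} dp_shard_in_ep={in_ep} consistent={consistent}")
```

Under these assumptions, the factoring only reproduces dp_shard when ep divides it evenly. With dp_shard = 2 and ep = 8, dp_shard_mod_ep comes out as 0, which is the case I'm asking about.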
Furthermore, the ep submesh is not a bona fide dimension of its own, but rather a combination of dp_shard_in_ep, cp and tp. Why can't ep be its own dimension? Currently ep is a factored submesh of dp_shard instead of a separate dimension, and I don't understand why.
I understand that the various mesh dimensions are combined into dp_shard_cp to reduce them to a 1D mesh, since FSDP accepts a 1D mesh and HSDP a 2D mesh.
But why can't the mesh dims be for example:
(assuming cp = 1, tp = 1, etp = 1)
world mesh: ['pp', 'dp_replicate', 'dp_shard', 'ep', 'cp', 'tp']
dp_shard mesh: ['dp_shard'] (not a flattening of ['dp_shard_in_ep', 'dp_shard_mod_ep'])
ep mesh: ['ep'] (not 'dp_shard_in_ep')
Sorry for all the questions, I'm just pretty confused as to what's going on. The most important question is: why does dp_shard need to be factored into two dimensions? I also think the ._flatten() function should be exposed publicly if so many places use it.