
Why is the ep mesh derived from a factoring of the dp mesh, instead of its own dimension? #1977

@man2machine


I see that the data parallel shard dimension is factored into two dimensions, dp_shard_mod_ep and dp_shard_in_ep.

The experts use the dp_shard_mod_ep submesh for FSDP, while the rest of the blocks use the regular dp_shard_cp submesh. Why can't the experts use FSDP on the regular dp_mesh? The reason is unclear to me after reading the code. If only expert parallelism is used without data parallelism, or if the data parallel size is smaller than the expert parallel size, then the dp_shard_mod_ep dimension size would be 0, which doesn't make sense.
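To make the degenerate case concrete, here is a minimal sketch of the arithmetic as I understand it. The helper name `factor_dp_shard` is hypothetical, and the formulas are my assumption based on the relations dp_shard = dp_shard_mod_ep * dp_shard_in_ep and ep = dp_shard_in_ep * cp * tp, with integer (floor) division; the actual code may differ.

```python
def factor_dp_shard(dp_shard: int, ep: int, cp: int = 1, tp: int = 1) -> tuple[int, int]:
    """Hypothetical helper: split dp_shard into (dp_shard_mod_ep, dp_shard_in_ep).

    Assumed relations (not confirmed against the real implementation):
        ep       = dp_shard_in_ep * cp * tp
        dp_shard = dp_shard_mod_ep * dp_shard_in_ep
    Floor division is used, so sizes that do not divide evenly collapse to 0.
    """
    dp_shard_in_ep = ep // (cp * tp)
    dp_shard_mod_ep = dp_shard * cp * tp // ep
    return dp_shard_mod_ep, dp_shard_in_ep


print(factor_dp_shard(dp_shard=8, ep=4))  # (2, 4): 8 = 2 * 4, a valid factoring
print(factor_dp_shard(dp_shard=1, ep=4))  # (0, 4): EP only, no data parallelism
print(factor_dp_shard(dp_shard=2, ep=8))  # (0, 8): dp size smaller than ep size
```

In the last two cases dp_shard_mod_ep comes out as 0, which is not a valid mesh dimension size.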

Furthermore, the ep submesh is not a bona fide dimension of its own, but rather a combination of dp_shard_in_ep, cp and tp. Why can't ep be its own dimension? Currently ep is a somewhat odd factored submesh of dp_shard rather than an independent dimension, and I don't understand why.

I understand that various mesh dimensions are combined into dp_shard_cp in order to reduce them to a 1D mesh, since FSDP accepts a 1D mesh and HSDP a 2D mesh.
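My mental model of that combining step, as a pure-Python sketch rather than the real DeviceMesh code: ranks are laid out row-major over the named dimensions, and flattening a set of dimensions groups together the ranks that differ only along those dimensions. The function name `flatten_groups` and the mesh sizes are made up for illustration.

```python
from itertools import product


def flatten_groups(shape: dict[str, int], dims: list[str]) -> list[list[int]]:
    """Rank groups obtained by merging `dims` into a single 1D mesh dimension.

    Ranks are numbered row-major over `shape` (dict insertion order).
    Each returned group holds the ranks that differ only along `dims`.
    """
    names = list(shape)
    sizes = [shape[name] for name in names]
    groups: dict[tuple[int, ...], list[int]] = {}
    # product() iterates with the last dimension fastest, i.e. row-major order,
    # so the enumeration index is the rank.
    for rank, coord in enumerate(product(*(range(size) for size in sizes))):
        # Ranks sharing the same coordinates outside `dims` land in one group.
        key = tuple(c for name, c in zip(names, coord) if name not in dims)
        groups.setdefault(key, []).append(rank)
    return list(groups.values())


mesh = {"dp_shard": 2, "cp": 2, "tp": 2}  # illustrative sizes, 8 ranks total
print(flatten_groups(mesh, ["dp_shard", "cp"]))  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```

Each group has dp_shard * cp = 4 ranks, which is the single 1D mesh that FSDP would see as dp_shard_cp.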

But why can't the mesh dims be, for example:

(assuming cp = 1, tp = 1, etp = 1)
world mesh: ['pp', 'dp_replicate', 'dp_shard', 'ep', 'cp', 'tp']
dp_shard mesh: ['dp_shard'] (not a flattening of ['dp_shard_in_ep', 'dp_shard_mod_ep'])
ep mesh: ['ep'] (not 'dp_shard_in_ep')
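
One difference I can see between the two layouts is the number of ranks each would need, since a mesh has one rank per coordinate. The sketch below only compares those products; the dimension sizes are illustrative assumptions, and the ordering of the factored layout is my guess at how the current mesh is built.

```python
from math import prod

# Sizes are illustrative assumptions, not taken from any real config.
# Factored layout: dp_shard (8) = dp_shard_mod_ep (2) * dp_shard_in_ep (4),
# and with cp = tp = 1 the ep group (4) reuses the dp_shard_in_ep ranks.
factored = {"pp": 1, "dp_replicate": 1, "dp_shard_mod_ep": 2,
            "dp_shard_in_ep": 4, "cp": 1, "tp": 1}

# Layout proposed above: ep is its own dimension next to dp_shard.
proposed = {"pp": 1, "dp_replicate": 1, "dp_shard": 8,
            "ep": 4, "cp": 1, "tp": 1}


def ranks_needed(mesh: dict[str, int]) -> int:
    """A mesh has one rank per coordinate, so it needs prod(sizes) ranks."""
    return prod(mesh.values())


print(ranks_needed(factored))  # 8
print(ranks_needed(proposed))  # 32
```

So with dp_shard = 8 and ep = 4, the proposed layout would span 32 ranks instead of 8. I'm not sure whether that is the intended reason for the factoring.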

Sorry for all the questions, I'm just pretty confused as to what's going on. The most important question is: why does dp_shard need to be factored into two dimensions? I also think the ._flatten() function should be exposed publicly if so many places use it.


Labels: question
