
Introduce FSDPv2 #6122

Merged
alanwaketan merged 5 commits into master from alanwaketan/fsdp_v2 on Dec 15, 2023

Conversation

@alanwaketan (Collaborator) commented on Dec 12, 2023:

Summary:
This patch introduces a PoC of FSDPv2. The full design doc is here: go/fsdp_v2. A real-world use case can be found at: https://github.com/pytorch-tpu/transformers/tree/llama2-spmd-fsdp.

Test Plan:
python test/spmd/test_fsdp_v2.py
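
For readers skimming the thread, here is a minimal end-to-end usage sketch assembled from the test snippets quoted in the review below. The import paths for the wrapper and the SPMD helpers are assumptions on my part (based on the files this PR touches), not a documented API:

```python
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as spmd
# Assumed import path for the wrapper introduced in this PR.
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()  # enable SPMD execution mode

# 2D ('fsdp', 'tensor') mesh over all devices, mirroring the test below.
num_devices = xr.global_runtime_device_count()
mesh = spmd.Mesh(np.arange(num_devices), (num_devices, 1), ('fsdp', 'tensor'))

model = nn.Linear(128, 64).to(xm.xla_device())
# Wrapping shards the parameters along the 'fsdp' axis; by default the module
# output is also marked as sharded (see the shard_output discussion below).
model = FSDPv2(model, mesh)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(16, 128, device=xm.xla_device())
loss = model(x).sum()
loss.backward()
optimizer.step()
xm.mark_step()
```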

alanwaketan self-assigned this on Dec 12, 2023
@jonb377 (Collaborator) left a comment:

Awesome stuff Jiewen! 👏

Comment thread: test/spmd/test_fsdp_v2.py (outdated)

        return output

      def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]:
        """Forward missing attributes to wrapped module."""
Collaborator:

This actually forwards all attributes defined on the wrapped module, right? Only attributes missing on the wrapped module will be retrieved from the SPMDFullyShardedDataParallel instance.

Collaborator Author:

I just copied & pasted from the existing wrapper, lol. Will need to revisit it.
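
To make the semantics concrete, here is a standalone illustration of the forwarding pattern being discussed (not the code in this PR): Python only calls `__getattr__` when normal lookup on the wrapper fails, so attributes defined on the wrapper itself win and everything else falls through to the wrapped module.

```python
import torch.nn as nn

class Wrapper(nn.Module):

  def __init__(self, module: nn.Module):
    super().__init__()
    self._orig_module = module  # registered as a submodule by nn.Module

  def __getattr__(self, name):
    try:
      # nn.Module's own lookup (parameters, buffers, submodules) first.
      return super().__getattr__(name)
    except AttributeError:
      # Anything not found on the wrapper falls through to the wrapped module.
      return getattr(self._orig_module, name)

w = Wrapper(nn.Linear(4, 2))
print(w.out_features)  # resolved on the wrapped nn.Linear -> 2
```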

Comment thread: test/spmd/test_fsdp_v2.py (outdated)

      mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
      model.fc1 = fsdp.SpmdFullyShardedDataParallel(model.fc1, mesh)
      model.fc2 = fsdp.SpmdFullyShardedDataParallel(model.fc2, mesh)
      model = fsdp.SpmdFullyShardedDataParallel(model, mesh)
Collaborator:

Do you mind adding some assertions that the sharding is correct here?
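
One hedged way such an assertion could look, assuming the private torch_xla._XLAC._get_xla_sharding_spec helper that the SPMD tests use elsewhere; the exact annotation string depends on the runtime and mesh shape:

```python
# Inside the test, after wrapping model.fc1:
import torch_xla

spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)
if self.n_devices > 1:
  # e.g. something like '{devices=[N,1]0,...,N-1}' for a weight sharded
  # along the 'fsdp' axis of an (N, 1) mesh.
  self.assertIn('devices=', spec)
```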

Contributor:

Is there a way to wrap recursively into the module? This is maybe more a question about usage: do you need to wrap each layer and then the module?

Contributor:

OK, so first wrap the submodules, and then wrapping the outer module will take care of the rest.

Collaborator Author:

Right, the auto-wrap will come later.
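
Until auto-wrap lands, here is a hypothetical helper (not part of this PR) sketching one way to wrap bottom-up: selected children first, then the root, matching the manual order used in the test above. The wrapper import path is again an assumption.

```python
import torch.nn as nn
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)  # assumed import path

def _wrap_children(module: nn.Module, mesh, should_wrap) -> None:
  # Depth-first: handle grandchildren before replacing the child itself.
  for name, child in module.named_children():
    _wrap_children(child, mesh, should_wrap)
    if should_wrap(child):
      setattr(module, name, FSDPv2(child, mesh))

def auto_wrap(module: nn.Module, mesh,
              should_wrap=lambda m: isinstance(m, nn.Linear)) -> nn.Module:
  _wrap_children(module, mesh, should_wrap)
  # Always wrap the root so the final output gets sharded as well.
  return FSDPv2(module, mesh)

# model = auto_wrap(model, mesh)
```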

Comment thread: torch_xla/experimental/spmd_fully_sharded_data_parallel.py

        raise RuntimeError(
            f"The output type is not supported: {type(output)}. Please provide your own shard_output callable.")

      spmd.mark_sharding(real_output, mesh, _prepare_spmd_partition_spec(real_output))
Collaborator:

It looks like we can use SPMDFullyShardedDataParallel to express data-parallel multislice by specifying a mesh like ('dcn', 'fsdp'), but we'd need to combine the two axes when sharding the activations' batch axis.

We could achieve this by specifying a shard_output function, but what do you think of allowing users to override which axes to shard activations along in the default shard_output_impl? e.g. a new constructor parameter for activation_sharding='fsdp', and allow it to be overridden to activation_sharding=('dcn', 'fsdp')

Collaborator Author:

Oops, I missed that.

Collaborator Author:

Let's do this as a follow-up.
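
For reference, a sketch of the interim workaround mentioned above: a custom shard_output that shards the activations' batch dimension across both the 'dcn' and 'fsdp' axes. It assumes the callable receives the module output and the mesh, and that a tuple entry in a partition spec maps multiple mesh axes onto one tensor dimension; the follow-up may well pick a different interface.

```python
import torch_xla.distributed.spmd as spmd

def shard_output_over_dcn_and_fsdp(output, mesh):
  real_output = output[0] if isinstance(output, tuple) else output
  # Batch dim sharded over ('dcn', 'fsdp'); remaining dims replicated.
  partition_spec = (('dcn', 'fsdp'),) + (None,) * (real_output.dim() - 1)
  spmd.mark_sharding(real_output, mesh, partition_spec)

# model = SpmdFullyShardedDataParallel(
#     model, mesh, shard_output=shard_output_over_dcn_and_fsdp)
```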

Comment thread: test/spmd/test_fsdp_v2.py (outdated)

      def test_fsdp_v2(self):
        model = self.SimpleLinear().to(xm.xla_device())
        mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
        model.fc1 = fsdp.SpmdFullyShardedDataParallel(model.fc1, mesh)
Collaborator:

Do we need to shard the individual layers, since the full model is wrapped on L27?

Contributor:

I would also go a step further and test two cases (see the sketch after this list):

  • wrap one of the inner modules and then the outer, and check that all are wrapped; say the two wrappings use different shardings, then make sure the original sharding is unchanged.
  • wrap the outer, then try/except when wrapping the inner.
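
A rough sketch of the first suggested case, reusing the private _get_xla_sharding_spec helper mentioned earlier; the expected behavior (inner sharding left untouched) is the reviewer's suggestion, not something this PR asserts yet:

```python
def test_nested_wrap_keeps_inner_sharding(self):
  model = self.SimpleLinear().to(xm.xla_device())
  inner_mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
  model.fc1 = fsdp.SpmdFullyShardedDataParallel(model.fc1, inner_mesh)
  inner_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)

  # Wrap the outer module with a different sharding layout.
  outer_mesh = self._get_mesh((1, self.n_devices), None, ('fsdp', 'tensor'))
  model = fsdp.SpmdFullyShardedDataParallel(model, outer_mesh)

  # The inner module's original sharding should be unchanged.
  self.assertEqual(
      inner_spec, torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
```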

Collaborator Author:

I hope this use case is well explained in the design doc.

Comment thread: torch_xla/experimental/spmd_fully_sharded_data_parallel.py (outdated)
@yeounoh (Contributor) left a comment:

Added some minor comments, thanks!

@alanwaketan (Collaborator Author):

@jonb377 and @yeounoh, thanks for the quick review. A lot of things will become clearer once the design doc is ready. Will keep you posted.


Comment thread: torch_xla/experimental/spmd_fully_sharded_data_parallel.py

    class SpmdFullyShardedDataParallel(nn.Module):

      def __init__(self, module: nn.Module, mesh: spmd.Mesh, shard_output: Optional[Callable] = None):
Collaborator:

Do we need to take the mesh from the caller? If it is FSDP, would it be possible to just assume we shard along the 0th dimension across all devices when the mesh is not provided?

Collaborator Author:

The thing is that the mesh will need to be shared with the dataloader, and a global mesh is a must for SPMD. I was thinking maybe we could have an API to set a global mesh. @jonb377 @yeounoh
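
For what it's worth, a sketch of what such a global-mesh API could look like; set_global_mesh/get_global_mesh under torch_xla.distributed.spmd is my guess at the naming, since no such API exists at the time of this PR:

```python
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as spmd

num_devices = xr.global_runtime_device_count()
mesh = spmd.Mesh(np.arange(num_devices), (num_devices, 1), ('fsdp', 'tensor'))
spmd.set_global_mesh(mesh)  # hypothetical setter

# The dataloader's input sharding and the FSDP wrapper could then both
# retrieve the same mesh instead of having it threaded through every call.
assert spmd.get_global_mesh() is mesh  # hypothetical getter
```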

alanwaketan changed the title from [WIP] Introduce FSDPv2 to Introduce FSDPv2 on Dec 15, 2023
@alanwaketan (Collaborator Author):

I have addressed some of the comments and polished the test case a bit. Feel free to re-review it. I will add more test cases once I'm back.

alanwaketan merged commit 4fe9fe7 into master on Dec 15, 2023
@alanwaketan (Collaborator Author):

I'm merging it in order to catch the 2.2 backport cutoff.

golechwierowicz pushed a commit that referenced this pull request on Jan 12, 2024.
alanwaketan mentioned this pull request on Jan 25, 2024.
bhavya01 pushed a commit that referenced this pull request on Apr 22, 2024.