Introduce FSDPv2 #6122
Conversation
jonb377
left a comment
Awesome stuff Jiewen! 👏
```python
    return output

  def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]:
    """Forward missing attributes to wrapped module."""
```
This actually forwards all attributes defined on the wrapped module, right? Only attributes missing on the SpmdFullyShardedDataParallel instance fall through to the wrapped module.
I just copied & pasted this from the existing wrapper, lol. Will need to revisit it.
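For context, a minimal sketch of the pattern under discussion, assuming the usual `nn.Module` attribute machinery (names here are illustrative, not the PR's exact code):

```python
import torch
import torch.nn as nn
from typing import Union

class WrapperSketch(nn.Module):
  """Illustrative wrapper. Python only calls __getattr__ when normal
  attribute lookup on the wrapper itself fails, so any attribute the
  wrapper does not define falls through to the wrapped module."""

  def __init__(self, module: nn.Module):
    super().__init__()
    self._orig_module = module  # registered as a submodule by nn.Module

  def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]:
    try:
      # nn.Module.__getattr__ resolves parameters, buffers, and submodules.
      return super().__getattr__(name)
    except AttributeError:
      # Not found on the wrapper: forward to the wrapped module.
      return getattr(self._orig_module, name)
```

So `wrapper.in_features` would resolve on a wrapped `nn.Linear`, while any name the wrapper itself defines shadows the module's.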
```python
mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
model.fc1 = fsdp.SpmdFullyShardedDataParallel(model.fc1, mesh)
model.fc2 = fsdp.SpmdFullyShardedDataParallel(model.fc2, mesh)
model = fsdp.SpmdFullyShardedDataParallel(model, mesh)
```
Do you mind adding some assertions that the sharding is correct here?
Is there a way to wrap the module recursively? This is maybe more of a usage question: do you need to wrap each layer and then the module?
OK, first wrap the submodules; then wrapping the outer module will take care of the rest.
Right, the auto-wrap will come later.
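For the assertion request, one possible shape of the checks; `_get_xla_sharding_spec` is the internal helper existing SPMD tests use to inspect annotations, so treat this as a sketch rather than part of the PR:

```python
import torch_xla

# After wrapping, the inner linear layers' weights should carry a
# sharding annotation; an empty spec string would mean no annotation.
fc1_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)
fc2_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)
assert fc1_spec != '', 'fc1.weight was never annotated'
assert fc2_spec != '', 'fc2.weight was never annotated'
```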
```python
raise RuntimeError(
    f"The output type is not supported: {type(output)}. Please provide your own shard_output callable.")
```

```python
spmd.mark_sharding(real_output, mesh, _prepare_spmd_partition_spec(real_output))
```
It looks like we can use SpmdFullyShardedDataParallel to express data-parallel multislice by specifying a mesh like ('dcn', 'fsdp'), but we'd need to combine the two axes when sharding the activations' batch axis.
We could achieve this by specifying a shard_output function, but what do you think of allowing users to override which axes the activations are sharded along in the default shard_output_impl? e.g. a new constructor parameter activation_sharding='fsdp' that could be overridden to activation_sharding=('dcn', 'fsdp').
Oops, I missed that.
Let's do this as a follow-up.
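In the meantime, a user-side version of the workaround mentioned above might look like the sketch below. The `(output, mesh)` signature for `shard_output` and the tuple-of-axes partition spec are assumptions based on this PR and the SPMD sharding API:

```python
import torch_xla.distributed.spmd as spmd

def shard_output(output, mesh):
  # Shard the batch dimension jointly across the 'dcn' and 'fsdp' axes
  # and replicate the feature dimension.
  spmd.mark_sharding(output, mesh, (('dcn', 'fsdp'), None))

model = fsdp.SpmdFullyShardedDataParallel(model, mesh, shard_output=shard_output)
```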
```python
def test_fsdp_v2(self):
  model = self.SimpleLinear().to(xm.xla_device())
  mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
  model.fc1 = fsdp.SpmdFullyShardedDataParallel(model.fc1, mesh)
```
Do we need to shard the individual layers, since the full model is wrapped on L27?
I would also go a step further and test 2 cases (a sketch follows after this exchange):
- Wrap one of the inner modules and then the outer, and check that all are wrapped. Say the two wrappings use different shardings; make sure the original sharding is unchanged.
- Wrap the outer module first, then try/except when wrapping the inner.
I hope this use case is well explained in the design doc.
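A sketch of the two suggested cases, reusing the helpers from the test file above; the second case's failure mode is the suggestion's expectation, not verified behavior of this PR:

```python
def test_nested_wrap_preserves_inner_sharding(self):
  model = self.SimpleLinear().to(xm.xla_device())
  mesh_a = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
  mesh_b = self._get_mesh((1, self.n_devices), None, ('fsdp', 'tensor'))
  model.fc1 = fsdp.SpmdFullyShardedDataParallel(model.fc1, mesh_a)
  model = fsdp.SpmdFullyShardedDataParallel(model, mesh_b)
  # The inner wrap used mesh_a; the outer wrap with mesh_b should not
  # overwrite fc1's existing annotation.
  spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)
  self.assertNotEqual(spec, '')

def test_wrap_inner_after_outer(self):
  model = self.SimpleLinear().to(xm.xla_device())
  mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor'))
  model = fsdp.SpmdFullyShardedDataParallel(model, mesh)
  # Per the suggestion, wrapping a submodule after the outer wrap is
  # expected to fail.
  with self.assertRaises(Exception):
    model.fc1 = fsdp.SpmdFullyShardedDataParallel(model.fc1, mesh)
```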
yeounoh
left a comment
Added some minor comments, thanks!
```python
class SpmdFullyShardedDataParallel(nn.Module):

  def __init__(self, module: nn.Module, mesh: spmd.Mesh, shard_output: Optional[Callable] = None):
```
Do we need to take the mesh from the caller? Since this is FSDP, would it be possible to just assume sharding along the 0th dimension across all devices when a mesh is not provided?
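For reference, a default along those lines could be built from the runtime device count; this is a sketch of the suggestion, not code in this PR:

```python
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as spmd

def _default_fsdp_mesh() -> spmd.Mesh:
  # Fall back to a 1D mesh over all devices with a single 'fsdp' axis,
  # i.e. shard dimension 0 across every device.
  num_devices = xr.global_runtime_device_count()
  return spmd.Mesh(np.arange(num_devices), (num_devices,), ('fsdp',))
```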
I have addressed some of the comments and polished the test case a bit. Feel free to re-review it. I will add more test cases once I'm back.
I'm merging it in order to catch the 2.2 backport cutoff.
Summary:
This patch introduces a PoC of FSDPv2. The full design doc is here: go/fsdp_v2. A real-world use case can be found at https://github.com/pytorch-tpu/transformers/tree/llama2-spmd-fsdp.
Test Plan:
python test/spmd/test_fsdp_v2.py