Introduce tensor sharding #14
Conversation
Looks great, thanks Jiewen!
mesh = xs.Mesh(device_ids, (num_devices, 1))
sharding_spec = xs.ShardingSpec(mesh, (0, 1))
elif self.args.spmd_tensor_sharding > 0:
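The `elif` branch above is where the new tensor-sharding path starts. As a rough orientation, here is a minimal sketch of what that branch could look like, assuming torch_xla's experimental SPMD API (`xs.Mesh` / `xs.ShardingSpec`, as in the diff); variable names other than those in the diff are illustrative, not the exact trainer code:

```python
# Minimal sketch, not the exact PR code: build the 2D (fsdp, tensor) mesh
# described in this PR and shard the input batch along the fsdp axis only.
import numpy as np
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

num_devices = xm.xrt_world_size()
device_ids = np.arange(num_devices)

tensor = 2                          # value passed via --spmd_tensor_sharding
fsdp = num_devices // tensor        # fsdp dimension derived from device count
mesh = xs.Mesh(device_ids, (fsdp, tensor))

# Inputs are (batch, seq); partition the batch dim on the fsdp axis, replicate seq.
input_sharding_spec = xs.ShardingSpec(mesh, (0, None))
```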
Specifying `batch` or `fsdp` sharding will silently override the tensor parallelism. Can we assert that these flags are exclusive, since `tensor_sharding` implies FSDP/batch sharding?
Actually, it's intended: you can do `tensor_sharding` on the weights and batch sharding on the input.
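To illustrate the intent of this reply (weights sharded across both mesh axes, inputs sharded only along the fsdp/batch axis), here is a hedged sketch using `xs.mark_sharding`; shapes and axis names are made up for the example and only the partition specs matter:

```python
# Hedged illustration of "tensor_sharding on weights, batch sharding on input".
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

num_devices = xm.xrt_world_size()
mesh = xs.Mesh(np.arange(num_devices), (num_devices // 2, 2), ('fsdp', 'tensor'))

weight = torch.randn(1024, 1024).to(xm.xla_device())
inputs = torch.zeros(8, 512, dtype=torch.long).to(xm.xla_device())

xs.mark_sharding(weight, mesh, (0, 1))     # 2D: rows on fsdp, cols on tensor
xs.mark_sharding(inputs, mesh, (0, None))  # batch dim on fsdp, seq replicated
```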
Oh I see, so we can run 2D FSDP by specifying `--spmd_batch_sharding` and e.g. `--spmd_tensor_sharding 4`. I think specifying `--spmd_fsdp_sharding` with `--spmd_tensor_sharding 4` will always ignore the `tensor_sharding` though, right?
What I mean is that you can specify `--spmd_batch_sharding --spmd_tensor_sharding 4`, but not `--spmd_fsdp_sharding --spmd_tensor_sharding 4`. Do you think that's clear? If not, I can do a follow-up.
Got it, thanks Jiewen! I think what we have now is fine. We can follow up later to make the sharding more standard, like MaxText's `ici_*_parallelism` and `dcn_*_parallelism` parameters.
Yea, for sure. Does HybridMesh do anything for you in a single slice?
Yeah, it will rearrange the tiling assignment to optimize for the ICI connections. I would say we should always use HybridMesh, even for a single slice.
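For reference, a sketch of the HybridMesh suggestion for the single-slice case; the argument names follow torch_xla's experimental SPMD API, but treat them as an assumption rather than the exact code this PR would add:

```python
import torch_xla.experimental.xla_sharding as xs

mesh = xs.HybridMesh(
    ici_mesh_shape=(2, 2),          # (fsdp, tensor) within one slice, e.g. a v4-8
    dcn_mesh_shape=(1, 1),          # single slice, so no cross-slice dimension
    axis_names=('fsdp', 'tensor'),
)
# HybridMesh reorders device ids so the mesh axes line up with the ICI links,
# which is the tiling-assignment optimization mentioned above.
```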
Trying to see if you have the MFU numbers to compare. I will make a change later.
I'll do a quick test on v4-8 to get the MFU difference.
No worries. It's not a priority.
Thanks Jon for approving.
Summary:
This pull request introduces a new way to do sharding that allows weights to be sharded on a two-dimensional mesh, i.e., (fsdp, tensor), with the input sharded along the fsdp dimension.
To enable it, pass `--spmd_tensor_sharding 2`, where 2 is the tensor dimension; the fsdp dimension is calculated automatically as num_devices // 2.

Test Plan:
Tested on a v4-8 with a 2B LLaMA.
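For concreteness, a small worked example of the num_devices // 2 calculation for the test plan above, assuming the 4 chips of a v4-8 are each visible as one device:

```python
# Worked example of the auto-derived fsdp dimension for --spmd_tensor_sharding 2.
num_devices = 4                 # v4-8: 4 chips visible as devices (assumption)
tensor = 2                      # from --spmd_tensor_sharding 2
fsdp = num_devices // tensor    # 4 // 2 = 2
print((fsdp, tensor))           # (2, 2): weights sharded on a (fsdp, tensor) mesh
```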