Support sequence parallelism and collective matmul #520
Conversation
cc @hfan could you also take a look? cc @xiangxu-google, you may also apply this on JAX models.
Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
@kyuyeunk, does this affect your refactor too?
def forward(self, input: torch.Tensor):
    with jax.named_scope(self.name):
        if self.enable_sequence_parallelism:
A unit test is probably needed.
I intended to add an e2e test, but couldn't find a good example.
tests/models/vllm/test_jax_XXX_linear.py are the existing unit tests.
Thanks, @hfan! The unit test is added and was manually triggered in https://buildkite.com/tpu-commons/tpu-commons-ci/builds/2043
Anything preventing it from being turned on by default?
@hfan I'm not sure this benefits all the models; maybe @QiliangCui can first try it in auto-tuning.
Yes, this does get affected. @yaochengji, I have created a draft PR so you can take an early look: #512
As discussed offline, quantized matmuls use kernels whose collectives can't be handled automatically by XLA, so we probably have to wait for your collective-matmul kernel to be safe (or only make this the default for the non-quantized code path).
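For context, the idea behind a collective (all-gather) matmul for this sequence-parallel, column-parallel case can be sketched in plain JAX with `shard_map`: each device multiplies the activation chunk it currently holds by its local weight shard, then rotates the chunk around the ring. This is only an illustration of the decomposition, not the kernel discussed above; the mesh axis name, shapes, and `check_rep=False` are assumptions.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = jax.devices()
n = len(devices)
mesh = Mesh(np.array(devices), ('model',))

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('model', None), P(None, 'model')),
         out_specs=P(None, 'model'),
         check_rep=False)  # skip replication checking for this sketch
def collective_matmul(x_blk, w_blk):
    # x_blk: [tokens / n, hidden], the sequence-sharded activation chunk.
    # w_blk: [hidden, out / n], the column-parallel weight shard.
    idx = jax.lax.axis_index('model')
    tokens_per_shard = x_blk.shape[0]
    out = jnp.zeros((tokens_per_shard * n, w_blk.shape[1]), x_blk.dtype)
    chunk = x_blk
    perm = [(j, (j + 1) % n) for j in range(n)]
    for i in range(n):
        # After i rotations this device holds the chunk that started on
        # device (idx - i) % n; write its partial product into those rows.
        src = (idx - i) % n
        out = jax.lax.dynamic_update_slice(
            out, chunk @ w_blk, (src * tokens_per_shard, 0))
        if i != n - 1:
            chunk = jax.lax.ppermute(chunk, 'model', perm)
    return out
```

A real kernel (or a compiler collective-matmul pass) overlaps the `ppermute` with the per-chunk matmul so communication hides behind compute; the loop above only spells out the data movement.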
Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
with jax.named_scope(self.name):
    if self.enable_sequence_parallelism:
        token_num = input.shape[0]
        # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
If the sharded token_num has to be larger than TPU_SECOND_LAST_MINOR, I guess the downside otherwise is wasted memory?
There will be more communication for the final result; you can give it a try.
# NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape['model'] >= TPU_SECOND_LAST_MINOR:
    input.shard_(NamedSharding(self.mesh, P('model', None)))
very neat and convenient way of doing self.apply_jax_(jax.lax.with_sharding_constraint, sharding)
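For readers following along, a minimal pure-JAX equivalent of the guarded sharding above might look like the sketch below; `TPU_SECOND_LAST_MINOR = 8` and the single 'model' mesh axis are assumptions, not values taken from this PR.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

TPU_SECOND_LAST_MINOR = 8  # assumed value of the tiling constraint referenced above

mesh = Mesh(np.array(jax.devices()), ('model',))

def maybe_shard_tokens(x):
    # Constrain the token dim to the 'model' axis only when every shard keeps
    # at least TPU_SECOND_LAST_MINOR rows; otherwise leave the layout to XLA.
    if x.shape[0] // mesh.shape['model'] >= TPU_SECOND_LAST_MINOR:
        x = jax.lax.with_sharding_constraint(
            x, NamedSharding(mesh, P('model', None)))
    return x
```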
if device_num == 8:
    ordered_devices = np.array([
        devices[0],
        devices[2],
I wonder why (1,0,0) maps to devices[2].
Also, does the order
# (0,0,0)
# (1,0,0)
# (0,1,0)
# (1,1,0)
# (0,2,0)
# (1,2,0)
# (0,3,0)
# (1,3,0)
matter?
Yeah, judging from how the order is constructed, it matters.
Why does (1,0,0) map to devices[2]?
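One hedged reading of the hard-coded permutation: the default `jax.devices()` order does not walk the chip coordinates in the ring order the diff comment describes, so devices[2] is presumably the chip at (1,0,0). A generic sketch that derives the same order from the TPU `coords` attribute (the sort key is a guess at the intended topology, not this PR's logic):

```python
import jax
import numpy as np
from jax.sharding import Mesh

devices = jax.devices()
if len(devices) == 8:
    # Sort by (y, x, z) coords so the 'model' axis visits
    # (0,0,0), (1,0,0), (0,1,0), (1,1,0), ... as listed in the diff comment.
    devices = sorted(devices, key=lambda d: (d.coords[1], d.coords[0], d.coords[2]))
mesh = Mesh(np.array(devices), ('model',))
```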
output = merged_column_linear(input_tensor).to(dtype)
# Set jax default device to workaround a layout bug in JAX 0.7.0 and earlier
with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
Is it using one TPU device? Do you need to test when there are multiple TPU devices?
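A minimal multi-device check along these lines could shard the token dim over every available chip and compare against the unsharded result; all names and shapes here are made up for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def test_sequence_parallel_matmul_matches_reference():
    mesh = Mesh(np.array(jax.devices()), ('model',))
    x = jnp.asarray(np.random.randn(64, 128), dtype=jnp.float32)
    w = jnp.asarray(np.random.randn(128, 256), dtype=jnp.float32)

    @jax.jit
    def sharded_matmul(x, w):
        x = jax.lax.with_sharding_constraint(
            x, NamedSharding(mesh, P('model', None)))
        return x @ w

    # The token-sharded result must match the plain (unsharded) matmul.
    np.testing.assert_allclose(sharded_matmul(x, w), x @ w, rtol=1e-5, atol=1e-5)
```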
# NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape['model'] >= TPU_SECOND_LAST_MINOR:
    input.shard_(NamedSharding(self.mesh, P('model', None)))
It seems that SP is implemented by sharding the num_tokens dimension. Do you need to do an all-gather at the very end? I couldn't find it in your PR.
Sharding propagation can handle this and make it correct.
Chatted offline: the all-gather is not needed because some later ops (e.g. selecting which tokens to compute logits for, which needs num_reqs tokens) hint the compiler to build the global view. At that point the compiler inserts an all-gather implicitly.
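A small sketch of that behavior: the hidden states stay token-sharded and no all-gather is written by hand, yet selecting a handful of token rows for the logits forces the compiler to materialize the global view (function and shapes are illustrative, not from this PR).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('model',))

@jax.jit
def select_tokens_and_project(hidden, lm_head, token_indices):
    # hidden: [num_tokens, d] stays sharded on the token dim through the model.
    hidden = jax.lax.with_sharding_constraint(
        hidden, NamedSharding(mesh, P('model', None)))
    # Gathering num_reqs rows needs the global token view, so the compiler
    # inserts the all-gather here on its own; nothing is written explicitly.
    selected = jnp.take(hidden, token_indices, axis=0)
    return selected @ lm_head  # [num_reqs, vocab]
```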
# NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape['model'] >= TPU_SECOND_LAST_MINOR:
    output.shard_(NamedSharding(self.mesh, P('model', None)))
Doesn't sequence parallelism usually use 'data' to shard the batch dim? Meaning, we could use both sequence and model parallelism and shard the inputs/outputs using P('data', 'model').
In this case, the mesh 'model' axis alone is enough. Here, sequence parallelism is applied to the layer_norm and model parallelism is applied to the matmul, as described in this paper: https://arxiv.org/abs/2205.05198
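A sketch of that split from the paper, expressed only with sharding constraints so XLA can lower the boundary to the usual all-gather / reduce-scatter pair; the RMSNorm variant and shapes are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('model',))

def rms_norm(x, scale, eps=1e-6):
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps) * scale

@jax.jit
def norm_then_column_parallel(x, scale, w_up):
    # Sequence parallelism: the norm runs on token-sharded activations.
    x = jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P('model', None)))
    x = rms_norm(x, scale)
    # Tensor (model) parallelism: the matmul shards the output features, so
    # the compiler gathers the tokens right at this boundary.
    w_up = jax.lax.with_sharding_constraint(
        w_up, NamedSharding(mesh, P(None, 'model')))
    return x @ w_up
```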
Description
-O '{"pass_config": {"enable_sequence_parallelism": true}}'
Tests
Checklist
Before submitting this PR, please make sure: