[Torchax] Refactor torchax layers to use vLLM APIs #512
Conversation
Force-pushed d5f5e64 to 409c10e
@hfan @yaochengji @vanbasten23 Work in progress. Pinging to get some early feedback.
Force-pushed d3a1945 to 577f5cd
@yarongmu-google @bythew3i @QiliangCui @lsy323 This PR is ready for review. This is a pretty big change, so please feel free to request any additional test results or design discussion.
Force-pushed 28c1e0a to f65a27b
Force-pushed c47c114 to abe3c95
tests/models/vllm/quantization/test_compressed_tensors_w8a8_int8.py (outdated review thread, resolved)
Considering it's a large change, could you enable the torchax+jax_runner tests in the CI, make sure they pass, and then merge? AFAIK, the torchax+jax_runner tests are disabled in the CI today.
Thanks @kyuyeunk. cc @bvrockwell: this is the refactor that brings torchax much closer to vLLM upstream to maximize reuse.
Just out of curiosity, do you know why we created our own layers such as JaxMergedColumnParallelLinear? |
Force-pushed abe3c95 to 597b58a
Force-pushed e8b978d to 5687edb
kyuyeunk left a comment:
@hfan As mentioned in one of the comments, I've separated part of the logic into a separate PR: #590
I intend to merge that PR first and then submit this one.
Additionally, I'm running performance benchmarks on all of the models we are tracking to confirm that performance has not changed. I will report back when I'm finished.
token_num = x.shape[0]
# NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape['model'] >= TPU_SECOND_LAST_MINOR:
    out.shard_(NamedSharding(self.mesh, P('model', None)))
Fixed it!
tpu_commons/models/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py (outdated review thread, resolved)
tpu_commons/models/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py (outdated review thread, resolved)
- with torchax.default_env():
+ with torchax.default_env(), set_vllm_model_wrapper_context(
+         kv_caches=None, mesh=self.mesh), set_forward_context(
+         attn_metadata=None, vllm_config=self.vllm_config):
If we don't add set_forward_context, this part of vLLM complains that the forward context is not set: https://github.com/vllm-project/vllm/blob/585e0bde36abdb2ab2967fd42005cbe62459020e/vllm/attention/layer.py#L268
I considered wrapping set_forward_context inside set_vllm_model_wrapper_context so that a single call sets both vllm_model_wrapper_context and forward_context, but I wasn't sure whether that's good API design, so I left it as-is.
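If the two contexts were ever merged as discussed, a minimal sketch could look like the following. The helper name is hypothetical and not part of this PR; the two setters are assumed to be the same callables used in the diff above, and the commented import paths are guesses:

```python
from contextlib import contextmanager

# Minimal sketch, not code from this PR. Assumed imports (paths are guesses):
# from vllm.forward_context import set_forward_context
# from <wherever it lives in tpu_commons> import set_vllm_model_wrapper_context

@contextmanager
def set_wrapper_and_forward_context(kv_caches, mesh, attn_metadata, vllm_config):
    # Bundle both contexts behind one `with` so call sites stay short.
    with set_vllm_model_wrapper_context(kv_caches=kv_caches, mesh=mesh), \
         set_forward_context(attn_metadata=attn_metadata, vllm_config=vllm_config):
        yield

# A call site would then collapse to:
# with torchax.default_env(), set_wrapper_and_forward_context(
#         kv_caches=None, mesh=self.mesh,
#         attn_metadata=None, vllm_config=self.vllm_config):
#     ...
```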
tests/models/vllm/quantization/test_compressed_tensors_w8a8_int8.py (outdated review thread, resolved)
Force-pushed 5687edb to 378adb4
Referenced https://github.com/QiliangCui/bm-infra/blob/main/cases/hourly_torchax_jax.csv and benchmarked the following models. Confirmed that the performance has not changed.
cc: @hfan
Force-pushed 378adb4 to f0bb371
Thanks! The performance diff for mistralai/Mistral-Small-24B-Instruct-2501 seems a bit larger than noise? Is it reproducible?
Force-pushed f0bb371 to 8905c17
kyuyeunk left a comment:
> Thanks! The performance diff for mistralai/Mistral-Small-24B-Instruct-2501 seems a bit larger than noise? Is it reproducible?

Turns out I was using the wrong --max-num-batched-tokens. Fixing it bridged the gap:
| model | before | after |
|---|---|---|
| mistralai/Mistral-Small-24B-Instruct-2501 | 20.39 | 20.85 |
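(For context, --max-num-batched-tokens caps how many tokens the scheduler packs into a single batch, which directly affects throughput. A purely illustrative way to set it via the Python API is sketched below; the value is a placeholder, not the benchmark configuration used here.)

```python
from vllm import LLM

# Illustrative only: value is a placeholder, not the setting used above.
llm = LLM(
    model="mistralai/Mistral-Small-24B-Instruct-2501",
    max_num_batched_tokens=8192,
)
```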
- with torchax.default_env():
+ with torchax.default_env(), set_vllm_model_wrapper_context(
+         kv_caches=None, mesh=self.mesh), set_forward_context(
+         attn_metadata=None, vllm_config=self.vllm_config):
Oh yeah. You were right. Deleted set_forward_context.
token_num = x.shape[0]
# NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape['model'] >= TPU_SECOND_LAST_MINOR:
    out.shard_(NamedSharding(self.mesh, P('model', None)))
btw, I'm slightly confused about what the difference is between sharding the input vs. the output. After sharding propagation, isn't the end result essentially the same?
Here is an example of RowParallelLinear:
# RowParallelLinear - Shard input
in      # ('model', None)
weight  # (None, 'model')
out     # not specified

After sharding propagation, out is sharded as ('model', None):

in      # ('model', None)
weight  # (None, 'model')
weight = allgather(weight, 'model')  # (None, None)
out = in * weight                    # ('model', None)

# RowParallelLinear - Shard output
in      # not specified
weight  # (None, 'model')
out     # ('model', None)

After sharding propagation, in is sharded as ('model', None):

in      # ('model', None)
weight  # (None, 'model')
weight = allgather(weight, 'model')  # (None, None)
out = in * weight                    # ('model', None)

And here is the scenario for ColumnParallelLinear:

# ColumnParallelLinear - Shard input
in      # ('model', None)
weight  # ('model', None)
out     # not specified

After sharding propagation, out is sharded as ('model', None):

in      # ('model', None)
weight  # ('model', None)
weight = allgather(weight, 'model')  # (None, None)
out = in * weight                    # ('model', None)

# ColumnParallelLinear - Shard output
in      # not specified
weight  # ('model', None)
out     # ('model', None)

After sharding propagation, in is sharded as ('model', None):

in      # ('model', None)
weight  # ('model', None)
weight = allgather(weight, 'model')  # (None, None)
out = in * weight                    # ('model', None)
Things might differ based on how previous / next layers are sharded, but that's my understanding. @yaochengji, do you have any insights? Or some unit tests to verify sharding?
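To make the question concrete, here is a minimal JAX sketch (not from this PR) that constrains only the input or only the output of a matmul and then inspects the sharding the compiler propagates; the 1-D 'model' mesh and the shapes are illustrative assumptions:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 1-D mesh over all available devices, axis named 'model'.
mesh = Mesh(np.array(jax.devices()), axis_names=('model',))

@jax.jit
def shard_input_only(x, w):
    # Constrain only the activation; the compiler propagates a sharding to the output.
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('model', None)))
    return x @ w

@jax.jit
def shard_output_only(x, w):
    # Constrain only the result; the compiler propagates a sharding back to the input.
    out = x @ w
    return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, P('model', None)))

x = jnp.ones((128, 256))
w = jnp.ones((256, 512))

# Comparing the two result shardings (or the compiled HLO via
# shard_input_only.lower(x, w).compile()) shows whether the end result
# actually differs for a given surrounding program.
print(shard_input_only(x, w).sharding)
print(shard_output_only(x, w).sharding)
```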
tests/models/vllm/quantization/test_compressed_tensors_w8a8_int8.py (outdated review thread, resolved)
Force-pushed 8905c17 to 67e1c1e
Confirmed that CI passes: https://buildkite.com/tpu-commons/tpu-commons-ci/builds/2429#0198ee7d-539b-4556-8db4-315e40888e7f
Force-pushed 67e1c1e to e44261a
if token_num // self.jax_config.mesh.shape[
        'model'] >= TPU_SECOND_LAST_MINOR:
    x.shard_(
        NamedSharding(self.jax_config.mesh, P('model', None)))
I can approve this PR except for the XLA collective matmul handling.
Will leave that to @yaochengji
Force-pushed efa22c2 to 39a7a2e
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
Force-pushed 39a7a2e to b563730
yaochengji left a comment:
LGTM, thanks for the awesome improvement!
Description
Instead of overriding vLLM layers with custom JAX layers, this PR aims to utilize the existing vLLM layers and use their provided APIs and templates (such as get_quantization_config and process_weights_after_loading) to call our own JAX code. This loses some flexibility but has the following benefits:
Using this change, I was able to implement FP8 model support for torchax with relative ease; I'll create a follow-up PR for it.
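For readers unfamiliar with these extension points, here is a rough, hypothetical sketch of the pattern the description refers to; the class name, signatures, and bodies are simplified illustrations, not the actual classes added in this PR (which plug into vLLM's real method/config base classes):

```python
import torch

# Sketch only: in the real integration this would subclass vLLM's linear/quant
# method base class; the method names mirror the hooks mentioned above.
class JaxLinearMethodSketch:
    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_size: int, params_dtype: torch.dtype) -> None:
        # Let vLLM allocate and load the weight exactly as it does upstream.
        weight = torch.nn.Parameter(
            torch.empty(output_size, input_size, dtype=params_dtype),
            requires_grad=False)
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Hook point: convert/shard the loaded weight for the JAX/TPU backend
        # (e.g. wrap it as a torchax tensor with a NamedSharding).
        pass

    def apply(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        # Hook point: dispatch the matmul to the JAX implementation instead of
        # a custom JaxMergedColumnParallelLinear-style layer override.
        return torch.nn.functional.linear(x, layer.weight)
```

The real vLLM hooks carry richer signatures (weight loaders, shard metadata, quantization parameters); the point is that the refactor plugs JAX code into these existing extension points rather than replacing whole layers.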
Tests
Ran the following models and verified that performance has not changed.
Checklist
Before submitting this PR, please make sure: