
Conversation

@kyuyeunk (Collaborator) commented Aug 19, 2025

Description

Instead of overriding vLLM layers with custom JAX layers, this PR utilizes the existing vLLM layers and uses their provided APIs and templates (such as get_quantization_config and process_weights_after_loading) to call our own JAX code.

This loses some flexibility but has the following benefits:

  1. We don't have to implement everything from scratch and can leverage pre-existing vLLM APIs.
  2. It plays more nicely with existing vLLM configs and features, which helps customers migrate from vLLM GPU to vLLM TPU.

Using this change, I was able to implement FP8 model support for torchax with relative ease; I'll create a follow-up PR for it.
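To illustrate the overall shape (not the exact code in this PR; the class and helper names below are placeholders, and the vLLM base-class signatures may differ slightly between versions), a quantization method that keeps vLLM's linear layers but routes weight post-processing and the matmul through JAX would look roughly like:

```python
# Rough sketch of the pattern, not the code in this PR. JaxLinearMethod,
# prepare_weight_for_jax and jax_linear are placeholder names; exact vLLM
# base-class signatures may differ between versions.
import torch
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.utils import set_weight_attrs


def prepare_weight_for_jax(weight: torch.Tensor) -> torch.Tensor:
    # Placeholder: convert / re-lay-out the loaded torch weight for JAX kernels.
    return weight.contiguous()


def jax_linear(x: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
    # Placeholder: stands in for the JAX matmul path (e.g. via torchax).
    return torch.nn.functional.linear(x, weight, bias)


class JaxLinearMethod(LinearMethodBase):
    """Reuses vLLM's linear layers but routes weight prep and compute to JAX."""

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        # Let vLLM allocate and load the weights exactly as it normally does.
        weight = torch.nn.Parameter(
            torch.empty(sum(output_partition_sizes), input_size_per_partition,
                        dtype=params_dtype),
            requires_grad=False)
        set_weight_attrs(weight, extra_weight_attrs)
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer):
        # Called once after checkpoint loading: convert weights for JAX here.
        layer.weight = torch.nn.Parameter(
            prepare_weight_for_jax(layer.weight.data), requires_grad=False)

    def apply(self, layer, x, bias=None):
        # Forward pass goes through our JAX path instead of vLLM's default.
        return jax_linear(x, layer.weight, bias)
```

A matching QuantizationConfig subclass would then hand this method out from its get_quant_method hook for linear layers.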

Tests

Ran the following models and verified that performance has not changed:

| model | before | after |
| --- | --- | --- |
| Qwen/Qwen2.5-14B-Instruct | 18.36 | 18.37 |
| Qwen/Qwen2.5-32B | 12.42 | 12.65 |
| Qwen/Qwen3-32B | 12.07 | 12.14 |
| RedHatAI/Meta-Llama-3.1-70B-Instruct-quantized.w8a8 | 8.82 | 8.72 |
| google/gemma-3-27b-it | 19.22 | 19.19 |
| meta-llama/Llama-3.1-70B-Instruct | 7.55 | 7.55 |
| mistralai/Codestral-22B-v0.1 | 14.94 | 14.98 |
| mistralai/Mistral-Small-24B-Instruct-2501 | 20.39 | 20.85 |
| mistralai/Mixtral-8x7B-Instruct-v0.1 | 18.51 | 18.61 |

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@github-actions

Description

Start with a short description of what the PR does and how this is a change from
the past.

The rest of the description includes relevant details and context, examples:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/123456
FIXES: #123456

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@kyuyeunk (Collaborator, Author)

@hfan @yaochengji @vanbasten23 work in progress. Pinging to get some early feedback.

@kyuyeunk force-pushed the torchax_quant_api branch 2 times, most recently from d3a1945 to 577f5cd on August 21, 2025 06:47
@kyuyeunk changed the title from "Support Quantization API" to "Refactor torchax layers to use vLLM APIs" on Aug 21, 2025
@kyuyeunk requested a review from hfan on August 21, 2025 06:48
@kyuyeunk marked this pull request as ready for review on August 21, 2025 06:48
@kyuyeunk (Collaborator, Author) commented Aug 21, 2025

> @hfan @yaochengji @vanbasten23 work in progress. Pinging to get some early feedback.

+ @yarongmu-google @bythew3i @QiliangCui @lsy323

This PR is ready for review. This is a pretty big change so please feel free to request any additional test results or design discussion.

@kyuyeunk force-pushed the torchax_quant_api branch 4 times, most recently from 28c1e0a to f65a27b on August 21, 2025 07:31
@kyuyeunk force-pushed the torchax_quant_api branch 2 times, most recently from c47c114 to abe3c95 on August 21, 2025 16:44
@vanbasten23 (Collaborator)

Considering it's a large change, could you enable the torchax+jax_runner tests in the CI, make sure they pass, and then merge? AFAIK, the torchax+jax_runner tests are disabled in the CI today.

@yarongmu-google (Collaborator)

Thanks @kyuyeunk. cc @bvrockwell: this is the refactor that brings torchax much closer to vLLM upstream to maximize reuse.

@vanbasten23 (Collaborator)

Just out of curiosity, do you know why we created our own layers such as JaxMergedColumnParallelLinear?

@kyuyeunk force-pushed the torchax_quant_api branch 3 times, most recently from e8b978d to 5687edb on August 27, 2025 04:14
@kyuyeunk (Collaborator, Author) left a comment

@hfan As mentioned in one of the comments, I've separated part of the logic out into a separate PR: #590

I intend to merge that PR first and then submit this PR.

Additionally, I'm running performance benchmarks on all models that we are tracking to confirm that performance has not changed. I will report back when I'm finished.

token_num = x.shape[0]
# NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape['model'] >= TPU_SECOND_LAST_MINOR:
    out.shard_(NamedSharding(self.mesh, P('model', None)))
@kyuyeunk (Collaborator, Author)

Fixed it!

- with torchax.default_env():
+ with torchax.default_env(), set_vllm_model_wrapper_context(
+         kv_caches=None, mesh=self.mesh), set_forward_context(
+         attn_metadata=None, vllm_config=self.vllm_config):
@kyuyeunk (Collaborator, Author)

If we don't add set_forward_context, this part of the vLLM code complains that the forward context is not set: https://github.com/vllm-project/vllm/blob/585e0bde36abdb2ab2967fd42005cbe62459020e/vllm/attention/layer.py#L268

I've considered wrapping set_forward_context inside set_vllm_model_wrapper_context so that a single call sets both vllm_model_wrapper_context and forward_context, but I wasn't sure whether that's a good interface design, so I left it as-is.
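For reference, the wrapper I had in mind would be something like the sketch below (set_model_contexts is a hypothetical name; it just nests the two existing context managers, which are assumed to be imported in this module as in the snippet above):

```python
from contextlib import contextmanager


# Hypothetical helper (not in this PR): nests the two existing context managers
# so one call sets both vllm_model_wrapper_context and forward_context.
@contextmanager
def set_model_contexts(kv_caches, mesh, vllm_config, attn_metadata=None):
    with set_vllm_model_wrapper_context(kv_caches=kv_caches, mesh=mesh), \
         set_forward_context(attn_metadata=attn_metadata,
                             vllm_config=vllm_config):
        yield
```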

@kyuyeunk (Collaborator, Author)

Referenced https://github.com/QiliangCui/bm-infra/blob/main/cases/hourly_torchax_jax.csv and benchmarked the following models. Confirmed that the performance has not changed.

| model | before | after |
| --- | --- | --- |
| Qwen/Qwen2.5-14B-Instruct | 18.36 | 18.37 |
| Qwen/Qwen2.5-32B | 12.42 | 12.65 |
| Qwen/Qwen3-32B | 12.07 | 12.14 |
| RedHatAI/Meta-Llama-3.1-70B-Instruct-quantized.w8a8 | 8.82 | 8.72 |
| google/gemma-3-27b-it | 19.22 | 19.19 |
| meta-llama/Llama-3.1-70B-Instruct | 7.55 | 7.55 |
| mistralai/Codestral-22B-v0.1 | 14.94 | 14.98 |
| mistralai/Mistral-Small-24B-Instruct-2501 | 20.39 | 19.72 |
| mistralai/Mixtral-8x7B-Instruct-v0.1 | 18.51 | 18.61 |

cc: @hfan

@hfan (Collaborator) commented Aug 27, 2025

> Referenced https://github.com/QiliangCui/bm-infra/blob/main/cases/hourly_torchax_jax.csv and benchmarked the following models. Confirmed that the performance has not changed.

Thanks! The performance diff for mistralai/Mistral-Small-24B-Instruct-2501 seems a bit larger than noise? Is it reproducible?

@kyuyeunk (Collaborator, Author) left a comment

> Thanks! The performance diff for mistralai/Mistral-Small-24B-Instruct-2501 seems a bit larger than noise? Is it reproducible?

Turns out I was using the wrong --max-num-batched-tokens. Fixing it closed the gap:

| model | before | after |
| --- | --- | --- |
| mistralai/Mistral-Small-24B-Instruct-2501 | 20.39 | 20.85 |

- with torchax.default_env():
+ with torchax.default_env(), set_vllm_model_wrapper_context(
+         kv_caches=None, mesh=self.mesh), set_forward_context(
+         attn_metadata=None, vllm_config=self.vllm_config):
@kyuyeunk (Collaborator, Author)

Oh yeah. You were right. Deleted set_forward_context.

token_num = x.shape[0]
# NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape['model'] >= TPU_SECOND_LAST_MINOR:
    out.shard_(NamedSharding(self.mesh, P('model', None)))
@kyuyeunk (Collaborator, Author)

btw, I'm slightly confused about the difference between sharding the input vs. the output. After sharding propagation, isn't the end result essentially the same?

Here is an example of RowParallelLinear:

# RowParallelLinear - Shard input

in # ('model', None) 
weight # (None, 'model') 
out  # not specified

After sharding propagation, out is sharded as ('model', None)

in # ('model', None) 
weight # (None, 'model')
weight = allgather(weight, 'model') # (None, None)
out = in * weight # ('model', None)


# RowParallelLinear - Shard output

in # not specified
weight # (None, 'model') 
out  # ('model', None) 

After sharding propagation, in is sharded as ('model', None)

in # ('model', None)
weight # (None, 'model')
weight = allgather(weight, 'model') # (None, None)
out = in * weight # ('model', None)

And here is the scenario for ColumnParallelLinear:

# ColumnParallelLinear - Shard input

in # ('model', None) 
weight # ('model', None) 
out  # not specified

After sharding propagation, out is sharded as ('model', None)

in # ('model', None) 
weight # ('model', None)
weight = allgather(weight, 'model') # (None, None)
out = in * weight # ('model', None)


# ColumnParallelLinear - Shard output

in # not specified
weight # ('model', None)
out  # ('model', None) 

After sharding propagation, in is sharded as ('model', None)

in # ('model', None)
weight # ('model', None)
weight = allgather(weight, 'model') # (None, None)
out = in * weight # ('model', None)

Things might differ based on how previous / next layers are sharded, but that's my understanding. @yaochengji, do you have any insights? Or some unit tests to verify sharding?
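One way I could check this in isolation is with plain JAX rather than the torchax wrappers: jit a bare matmul, pin either the operands or the output, and inspect what propagation picks. A minimal sketch (assumes a multi-device host such as a TPU VM and a 1-D 'model' mesh axis; shapes here are arbitrary examples):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes a multi-device host (e.g. a TPU VM); 'model' mirrors the mesh axis in the PR.
mesh = Mesh(np.array(jax.devices()), ('model',))

# RowParallelLinear-style shapes: activations plus a weight sharded on its contracting dim.
x = jax.device_put(jnp.ones((128, 512)), NamedSharding(mesh, P('model', None)))
w = jax.device_put(jnp.ones((512, 256)), NamedSharding(mesh, P(None, 'model')))


@jax.jit
def shard_input_only(x, w):
    # Only the operands carry shardings; the output is left to propagation.
    return x @ w


@jax.jit
def shard_output_too(x, w):
    # Additionally constrain the output and let propagation reconcile the rest.
    out = x @ w
    return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, P('model', None)))


print(shard_input_only(x, w).sharding)   # what propagation picked for the output
print(shard_output_too(x, w).sharding)   # the explicitly requested output sharding
# To compare the collectives XLA actually inserts:
# print(shard_input_only.lower(x, w).compile().as_text())
```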

@kyuyeunk (Collaborator, Author)

if token_num // self.jax_config.mesh.shape[
        'model'] >= TPU_SECOND_LAST_MINOR:
    x.shard_(
        NamedSharding(self.jax_config.mesh, P('model', None)))
Collaborator

I can approve this PR except for the XLA collective matmul handling.

Will leave it to @yaochengji

@hfan self-requested a review on August 28, 2025 19:22
@kyuyeunk force-pushed the torchax_quant_api branch 4 times, most recently from efa22c2 to 39a7a2e on August 29, 2025 05:55
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
@kyuyeunk changed the title from "Refactor torchax layers to use vLLM APIs" to "[Refactor torchax layers to use vLLM APIs" on Aug 29, 2025
@kyuyeunk changed the title from "[Refactor torchax layers to use vLLM APIs" to "[Torchax] Refactor torchax layers to use vLLM APIs" on Aug 29, 2025
@yaochengji (Collaborator) left a comment

LGTM, thanks for the awesome improvement!

@kyuyeunk merged commit 51a2e09 into main on Aug 29, 2025
1 of 2 checks passed
@kyuyeunk deleted the torchax_quant_api branch on August 29, 2025 21:32