
Conversation

@yaochengji (Collaborator) commented Aug 20, 2025

Description

  • Add sequence parallelism sharding annotations for torchax dense models.
  • Fix the logical device id ordering issue on v6e-4 and v6e-8.
  • The feature can be enabled with -O '{"pass_config": {"enable_sequence_parallelism": true}}'.

Tests

MODEL_IMPL_TYPE=vllm vllm serve Qwen/Qwen2.5-32B  --seed 42  --disable-log-requests  --tensor-parallel-size 8 --max-model-len 2048 --gpu-memory-utilization 0.96 --no-enable-prefix-caching --max-num-seqs 256 --max-num-batched-tokens 4096 -O '{"pass_config": {"enable_sequence_parallelism": true}}' |& tee run.log
python3 ./benchmarks/benchmark_serving.py --model Qwen/Qwen2.5-32B --dataset-name sonnet --dataset-path benchmarks/sonnet_4x.txt --sonnet-input-len 1800 --sonnet-output-len 128 --ignore_eos
MODEL_IMPL_TYPE=vllm vllm serve meta-llama/Llama-3.1-70B-Instruct  --seed 42  --disable-log-requests  --tensor-parallel-size 8 --max-model-len 2048 --gpu-memory-utilization 0.96 --no-enable-prefix-caching --max-num-seqs 128 --max-num-batched-tokens 2048 -O '{"pass_config": {"enable_sequence_parallelism": true}}' |& tee run.log
python3 ./benchmarks/benchmark_serving.py --model meta-llama/Llama-3.1-70B-Instruct --dataset-name sonnet --dataset-path benchmarks/sonnet_4x.txt --sonnet-input-len 1800 --sonnet-output-len 128 --ignore_eos

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@yaochengji requested a review from lsy323 on August 20, 2025 at 17:15
@yaochengji (Collaborator, Author) commented Aug 20, 2025

Model              Best throughput (before)   Throughput (new)   Improvement
Qwen/Qwen2.5-32B   12.1 req/s                 14.4 req/s         18.9%
Llama-3.1-70B      7.58 req/s                 8.1 req/s          6.9%

@yaochengji (Collaborator, Author) commented:

cc @hfan, could you also take a look?

cc @xiangxu-google, you may also want to apply this to the JAX models.

@hfan (Collaborator) commented Aug 20, 2025

@kyuyeunk - this affects your refactor too?


def forward(self, input: torch.Tensor):
    with jax.named_scope(self.name):
        if self.enable_sequence_parallelism:
Collaborator:

A unit test is probably needed.

Collaborator (Author):

I intended to add an e2e test, but couldn't find a good example to follow.

Collaborator:

The existing unit tests are in tests/models/vllm/test_jax_XXX_linear.py.

Collaborator (Author):

Thanks, @hfan! The unit test has been added and was manually triggered in https://buildkite.com/tpu-commons/tpu-commons-ci/builds/2043

@hfan (Collaborator) commented Aug 20, 2025

Anything preventing it from being turned on by default?

@yaochengji (Collaborator, Author) commented:

> Anything preventing it from being turned on by default?

@hfan I'm not sure this can benefit all the models; maybe @QiliangCui can first try this in auto-tuning.

@kyuyeunk (Collaborator) commented:

> @kyuyeunk - this affects your refactor too?

Yes, this does get affected. @yaochengji, I have created a draft PR so you can take an early look: #512

@hfan (Collaborator) commented Aug 20, 2025

> > Anything preventing it from being turned on by default?
>
> @hfan I'm not sure this can benefit all the models; maybe @QiliangCui can first try this in auto-tuning.

As discussed offline, quantized matmuls use kernels whose collectives can't be automatically handled by XLA, so to be safe we probably have to wait for your collective-matmul kernel (or only make it the default for the non-quantized code path).

Chengji Yao and others added 2 commits August 20, 2025 20:02
@yaochengji merged commit 1badfe8 into main on Aug 20, 2025
2 checks passed
with jax.named_scope(self.name):
    if self.enable_sequence_parallelism:
        token_num = input.shape[0]
        # NOTE(chengjiyao): make sure the sharded token_num is at least TPU_SECOND_LAST_MINOR
Collaborator:

If "sharded token_num is larger than TPU_SECOND_LAST_MINOR", I guess the downside is waste of memory?

Collaborator (Author):

There will be more communication in the final result; you can give it a try.

# NOTE(chengjiyao): make sure the sharded token_num is at least TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape['model'] >= TPU_SECOND_LAST_MINOR:
    input.shard_(NamedSharding(self.mesh, P('model', None)))
Collaborator:

very neat and convenient way of doing self.apply_jax_(jax.lax.with_sharding_constraint, sharding)

if device_num == 8:
    ordered_devices = np.array([
        devices[0],
        devices[2],
Collaborator:

I wonder why (1,0,0) maps to devices[2].

Also, does the order

            # (0,0,0)
            # (1,0,0)
            # (0,1,0)
            # (1,1,0)
            # (0,2,0)
            # (1,2,0)
            # (0,3,0)
            # (1,3,0)

matter?

Collaborator (Author):

Yeah, the order matters.
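
For reference, the coordinate order listed above can be reproduced by sorting the devices by their physical coordinates. A minimal sketch, assuming the TPU devices expose a coords attribute (as JAX TPU device objects do); this is illustrative, not the PR's code:

    import jax
    import numpy as np

    devices = jax.devices()
    # Each TPU device reports its physical chip coordinates, e.g. d.coords == (x, y, z).
    for d in devices:
        print(d.id, d.coords)

    # Sorting by the second, then first coordinate yields
    # (0,0,0), (1,0,0), (0,1,0), (1,1,0), ... as in the comment above.
    ordered_devices = np.array(
        sorted(devices, key=lambda d: (d.coords[1], d.coords[0])))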

Collaborator:

Why does (1,0,0) map to devices[2]?

output = merged_column_linear(input_tensor).to(dtype)

# Set the jax default device to work around a layout bug in JAX 0.7.0 and earlier
with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
Collaborator:

Is it using one TPU device? Do you need to test with multiple TPU devices?

# NOTE(chengjiyao): make sure the sharded token_num is at least TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape['model'] >= TPU_SECOND_LAST_MINOR:
    input.shard_(NamedSharding(self.mesh, P('model', None)))
Collaborator:

It seems that SP is implemented by sharding the num_tokens dimension. Do you need to do an all-gather at the very end? I couldn't find it in your PR.

Collaborator (Author):

Sharding propagation can handle this and make it correct.

Collaborator:

Chatted offline: the explicit all-gather is not needed, because some later ops (e.g. selecting which tokens to compute logits for, which needs num_reqs tokens) hint the compiler that a global view is required. At that point the compiler inserts an all-gather implicitly.
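
A minimal JAX-only sketch of that propagation behavior (illustrative; the function name, shapes, and the final slice are made up for the example and are not the PR's code):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), ('model',))

    @jax.jit
    def layer(x):
        # Sequence parallelism: constrain the token dimension to be sharded
        # across the 'model' axis. No explicit all-gather is written here.
        x = jax.lax.with_sharding_constraint(
            x, NamedSharding(mesh, P('model', None)))
        x = x * 2.0  # stand-in for the norm / elementwise work
        # A later op that needs specific tokens (e.g. picking the rows used
        # for logits) tells the compiler a global view is required, and it
        # inserts the all-gather (or an equivalent collective) itself.
        return x[:4]

    out = layer(jnp.ones((1024, 128)))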

# NOTE(chengjiyao): make sure the sharded token_num is at least TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape['model'] >= TPU_SECOND_LAST_MINOR:
    output.shard_(NamedSharding(self.mesh, P('model', None)))
Collaborator:

Doesn't sequence parallelism usually use 'data' to shard the batch dim? Meaning, we could use both sequence and model parallelism and shard the inputs/outputs with P('data', 'model').

@yaochengji (Collaborator, Author) commented Aug 23, 2025:

In this case, the mesh "model" axis alone is enough. Here sequence parallelism is applied to the layer_norm, and model parallelism is applied to the matmul, as described in this paper: https://arxiv.org/abs/2205.05198
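
A rough JAX sketch of that split under a single 'model' mesh axis (illustrative only; the norm is a stand-in, the shapes are made up, and the weight sharding mirrors the usual column-parallel layout rather than this PR's exact code):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), ('model',))
    tokens, hidden, ffn = 2048, 512, 2048

    # Tensor parallelism: shard the weight's output-feature dimension.
    w = jax.device_put(jnp.ones((hidden, ffn)),
                       NamedSharding(mesh, P(None, 'model')))

    @jax.jit
    def block(x):
        # Sequence parallelism: the cheap elementwise norm runs on a
        # token-sharded activation.
        x = jax.lax.with_sharding_constraint(
            x, NamedSharding(mesh, P('model', None)))
        x = x / (jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True)) + 1e-6)  # RMSNorm-like
        # Model parallelism: the matmul needs the full sequence per feature
        # shard; the compiler resolves the conflicting shardings itself,
        # typically by all-gathering x along the token axis so the output can
        # stay sharded along the feature ('model') axis.
        return x @ w

    out = block(jnp.ones((tokens, hidden)))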
