9 changes: 7 additions & 2 deletions tests/models/vllm/test_jax_merged_column_parallel_linear.py
@@ -52,7 +52,9 @@ def setup_environment():
 @pytest.mark.parametrize("bias", [False, True])
 @pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh()])
 @pytest.mark.parametrize("fuse_matmuls", [False, True])
-def test_jax_merged_column_parallel_linear(bias, mesh, fuse_matmuls):
+@pytest.mark.parametrize("enable_sp", [False, True])
+def test_jax_merged_column_parallel_linear(bias, mesh, fuse_matmuls,
+                                           enable_sp):
     dtype = torch.bfloat16
 
     merged_column_linear = MergedColumnParallelLinear(
@@ -78,7 +80,10 @@ def test_jax_merged_column_parallel_linear(bias, mesh, fuse_matmuls):
     # Set jax default device to workaround a layout bug in JAX 0.7.0 and earlier
     with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
Collaborator: Is it using one TPU device? Do you need to test when there are multiple TPU devices?
         jax_merged_column_linear = JaxMergedColumnParallelLinear(
-            merged_column_linear, mesh, fuse_matmuls)
+            merged_column_linear,
+            mesh,
+            fuse_matmuls,
+            enable_sequence_parallelism=enable_sp)
         jax_input_tensor = torch_view(t2j(input_tensor))
         jax_input_tensor.apply_jax_(jax.device_put,
                                     NamedSharding(mesh, P(None, None)))
6 changes: 4 additions & 2 deletions tests/models/vllm/test_jax_row_parallel_linear.py
@@ -51,7 +51,8 @@ def setup_environment():
 
 @pytest.mark.parametrize("bias", [False, True])
 @pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh()])
-def test_jax_row_parallel_linear(bias, mesh):
+@pytest.mark.parametrize("enable_sp", [False, True])
+def test_jax_row_parallel_linear(bias, mesh, enable_sp):
     dtype = torch.bfloat16
 
     engine_args = EngineArgs(
@@ -82,7 +83,8 @@ def test_jax_row_parallel_linear(bias, mesh):
 
     # Set jax default device to workaround a layout bug in JAX 0.7.0 and earlier
     with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
-        jax_row_linear = JaxRowParallelLinear(row_linear, mesh=mesh)
+        jax_row_linear = JaxRowParallelLinear(
+            row_linear, mesh=mesh, enable_sequence_parallelism=enable_sp)
         jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
         jax_input_tensor.apply_jax_(jax.device_put,
                                     NamedSharding(mesh, P(None, None)))
12 changes: 7 additions & 5 deletions tpu_commons/models/vllm/jax_merged_column_parallel_linear.py
@@ -9,10 +9,12 @@
 class JaxMergedColumnParallelLinear(JaxMergedColumnParallelLinearCore):
 
     def __init__(self, merged_col_parallel_linear: torch.nn.Module, mesh: Mesh,
-                 fuse_matmuls: bool):
+                 fuse_matmuls: bool, enable_sequence_parallelism: bool):
         assert isinstance(merged_col_parallel_linear,
                           MergedColumnParallelLinear)
-        super().__init__(merged_col_parallel_linear,
-                         mesh,
-                         "JaxMergedColumnParallelLinear",
-                         fuse_matmuls=fuse_matmuls)
+        super().__init__(
+            merged_col_parallel_linear,
+            mesh,
+            "JaxMergedColumnParallelLinear",
+            fuse_matmuls=fuse_matmuls,
+            enable_sequence_parallelism=enable_sequence_parallelism)
@@ -14,6 +14,7 @@
     forward_unqunatized, forward_w8a8_int8_col_parallel,
     reorder_concatenated_tensor_for_sharding,
     slice_sharded_tensor_for_concatenation)
+from tpu_commons.utils import TPU_SECOND_LAST_MINOR
 
 P = PartitionSpec
 

Expand All @@ -22,7 +23,7 @@ class JaxMergedColumnParallelLinearCore(torch.nn.Module):
""" A common class to implement Column Parallel Linear layer whose weight are merged from a list of smaller weight tensors, e.g. vLLM's MergedColumnParallelLinear and QKVParallelLinear layer. """

def __init__(self, vllm_col_par_linear: torch.nn.Module, mesh: Mesh,
name: str, fuse_matmuls: bool):
name: str, fuse_matmuls: bool, enable_sequence_parallelism):
super().__init__()

self.gather_output = vllm_col_par_linear.gather_output
Expand All @@ -33,6 +34,7 @@ def __init__(self, vllm_col_par_linear: torch.nn.Module, mesh: Mesh,
self.name = name
self.fuse_matmuls = fuse_matmuls
self.has_bias = vllm_col_par_linear.bias is not None
self.enable_sequence_parallelism = enable_sequence_parallelism
self.n_matmuls = len(self.output_sizes)
assert vllm_col_par_linear.tp_size == 1, (
"The model has to be loaded with TP== 1 in order to run in Jax SPMD."
@@ -230,7 +232,14 @@ def forward_split(self, input):
 
     def forward(self, input: torch.Tensor):
         with jax.named_scope(self.name):
+            if self.enable_sequence_parallelism:
Collaborator: A unit test is probably needed.

Collaborator Author: I intended to add an e2e test, but could not find a good example.

Collaborator: tests/models/vllm/test_jax_XXX_linear.py are the existing unit tests.

Collaborator Author: Thanks, @hfan! The unit test is added, manually triggered in https://buildkite.com/tpu-commons/tpu-commons-ci/builds/2043

+                token_num = input.shape[0]
+                # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
Collaborator: Regarding "make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR": I guess the downside otherwise would be wasted memory?

Collaborator Author: There will be more communication in the final result; you can give it a try.

+                if token_num // self.mesh.shape[
+                        'model'] >= TPU_SECOND_LAST_MINOR:
+                    input.shard_(NamedSharding(self.mesh, P('model', None)))
Collaborator: Very neat and convenient way of doing self.apply_jax_(jax.lax.with_sharding_constraint, sharding).

Collaborator: It seems that SP is implemented by sharding the num_tokens dimension. Do you need to do an all-gather at the very end? I couldn't find it in your PR.

Collaborator Author: Sharding propagation can handle this and make it correct.

Collaborator: Chatted offline: the all-gather is not needed because some later ops (e.g. selecting which tokens to compute logits for, which needs num_reqs tokens) may hint the compiler toward the global view. At that point the compiler will do an all-gather implicitly.

             if self.fuse_matmuls:
-                return self.forward_fused(input)
+                output, output_bias = self.forward_fused(input)
             else:
-                return self.forward_split(input)
+                output, output_bias = self.forward_split(input)
+        return output, output_bias
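The points discussed above (the token-count guard and the absence of an explicit all-gather) can be illustrated outside tpu_commons with a few lines of plain JAX. This is a minimal sketch, not the PR's code: it assumes a 1-D mesh whose only axis is named 'model', uses jax.lax.with_sharding_constraint where the layer uses torchax's shard_, and relies on SPMD sharding propagation to insert any needed collective.

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

TPU_SECOND_LAST_MINOR = 8  # TPU tiles the second-to-last dim in units of 8

# A 1-D mesh whose only axis is named 'model', like the runner's TP axis.
mesh = jax.make_mesh((len(jax.devices()),), ("model",))


@jax.jit
def column_parallel_forward(x, w):
    token_num = x.shape[0]
    # Only shard the token dim when every shard keeps >= 8 rows; e.g. with
    # 8 devices, 32 tokens would leave 4 rows per shard, so skip sharding.
    if token_num // mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
        # torchax's input.shard_(...) boils down to this sharding constraint.
        x = jax.lax.with_sharding_constraint(
            x, NamedSharding(mesh, P("model", None)))
    y = x @ w
    # No explicit all-gather is written: if a later op needs the full
    # activation, SPMD sharding propagation inserts the collective for us.
    return y


x = jnp.ones((128, 512), jnp.bfloat16)
w = jnp.ones((512, 1024), jnp.bfloat16)
print(column_parallel_forward(x, w).shape)  # (128, 1024)
```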
12 changes: 7 additions & 5 deletions tpu_commons/models/vllm/jax_qkv_parallel_linear.py
@@ -9,9 +9,11 @@
 class JaxQKVParallelLinear(JaxMergedColumnParallelLinearCore):
 
     def __init__(self, qkv_linear: torch.nn.Module, mesh: Mesh,
-                 fuse_matmuls: bool):
+                 fuse_matmuls: bool, enable_sequence_parallelism: bool):
         assert isinstance(qkv_linear, QKVParallelLinear)
-        super().__init__(qkv_linear,
-                         mesh,
-                         "JaxQKVParallelLinear",
-                         fuse_matmuls=fuse_matmuls)
+        super().__init__(
+            qkv_linear,
+            mesh,
+            "JaxQKVParallelLinear",
+            fuse_matmuls=fuse_matmuls,
+            enable_sequence_parallelism=enable_sequence_parallelism)
11 changes: 10 additions & 1 deletion tpu_commons/models/vllm/jax_row_parallel_linear.py
@@ -15,20 +15,23 @@
 
 from tpu_commons.models.vllm.jax_linear_common import (
     forward_unqunatized, forward_w8a8_int8_row_parallel)
+from tpu_commons.utils import TPU_SECOND_LAST_MINOR
 
 P = PartitionSpec
 
 
 class JaxRowParallelLinear(torch.nn.Module):
 
-    def __init__(self, row_linear: torch.nn.Module, mesh: Mesh):
+    def __init__(self, row_linear: torch.nn.Module, mesh: Mesh,
+                 enable_sequence_parallelism: bool):
         super().__init__()
         assert isinstance(row_linear, RowParallelLinear)
 
         self.mesh = mesh
         self.reduce_results = row_linear.reduce_results
         self.skip_bias_add = row_linear.skip_bias_add
         self.return_bias = row_linear.return_bias
+        self.enable_sequence_parallelism = enable_sequence_parallelism
 
         self.w8q8_int8_quant = False
         if isinstance(row_linear.quant_method,
@@ -96,4 +99,10 @@ def forward(self, input: torch.Tensor):
         if not self.return_bias:
             return output
         output_bias = self.bias if self.skip_bias_add else None
+        if self.enable_sequence_parallelism:
+            token_num = input.shape[0]
+            # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
+            if token_num // self.mesh.shape[
+                    'model'] >= TPU_SECOND_LAST_MINOR:
+                output.shard_(NamedSharding(self.mesh, P('model', None)))
Collaborator: Doesn't sequence parallelism usually use 'data' to shard the batch dim? Meaning, we could use both sequence and model parallelism and shard the inputs/outputs using P('data', 'model').

Collaborator Author (@yaochengji, Aug 23, 2025): In this case, only the mesh "model" axis is enough. Here the sequence parallelism is applied on the layer norm, and model parallelism is applied on the matmul. It is described in this paper: https://arxiv.org/abs/2205.05198
         return output, output_bias
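The scheme the author references (arXiv:2205.05198) puts sequence parallelism on the norm and tensor parallelism on the matmuls, reusing a single mesh axis. Below is a rough, self-contained sketch of that layout in plain JAX; the shapes and the rms_norm helper are made up for illustration and are not taken from the PR.

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((len(jax.devices()),), ("model",))


def rms_norm(x):
    # Token-wise norm: rows are independent, so it runs fine on token shards.
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + 1e-6)


@jax.jit
def block(x, w_up, w_down):
    # Sequence parallelism: tokens sharded over the 'model' axis for the norm.
    x = jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P("model", None)))
    h = rms_norm(x)
    # Tensor parallelism: weights sharded over the same axis
    # (column-parallel up-projection, row-parallel down-projection).
    w_up = jax.lax.with_sharding_constraint(
        w_up, NamedSharding(mesh, P(None, "model")))
    w_down = jax.lax.with_sharding_constraint(
        w_down, NamedSharding(mesh, P("model", None)))
    h = jax.nn.gelu(h @ w_up)
    out = h @ w_down
    # The row-parallel output goes back to token sharding, which is what
    # JaxRowParallelLinear's output.shard_ above expresses.
    return jax.lax.with_sharding_constraint(
        out, NamedSharding(mesh, P("model", None)))


x = jnp.ones((256, 512), jnp.bfloat16)
w_up = jnp.ones((512, 2048), jnp.bfloat16)
w_down = jnp.ones((2048, 512), jnp.bfloat16)
print(block(x, w_up, w_down).shape)  # (256, 512)
```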
20 changes: 16 additions & 4 deletions tpu_commons/models/vllm/sharding.py
@@ -44,16 +44,24 @@ def shard_attention(layer: torch.nn.Module, mesh: Mesh,
 def shard_qkv_parallel_linear(layer: torch.nn.Module, mesh: Mesh,
                               vllm_config: VllmConfig):
     assert isinstance(layer, QKVParallelLinear)
-    jax_layer = JaxQKVParallelLinear(layer, mesh,
-                                     shard_qkv_parallel_linear.fuse_matmuls)
+    jax_layer = JaxQKVParallelLinear(
+        layer,
+        mesh,
+        shard_qkv_parallel_linear.fuse_matmuls,
+        enable_sequence_parallelism=vllm_config.compilation_config.pass_config.
+        enable_sequence_parallelism)
     return jax_layer
 
 
 def shard_merged_column_parallel_linear(layer: torch.nn.Module, mesh: Mesh,
                                         vllm_config: VllmConfig):
     assert isinstance(layer, MergedColumnParallelLinear)
     jax_layer = JaxMergedColumnParallelLinear(
-        layer, mesh, shard_merged_column_parallel_linear.fuse_matmuls)
+        layer,
+        mesh,
+        shard_merged_column_parallel_linear.fuse_matmuls,
+        enable_sequence_parallelism=vllm_config.compilation_config.pass_config.
+        enable_sequence_parallelism)
     return jax_layer
 
 
@@ -73,7 +81,11 @@ def shard_column_parallel_linear(layer: torch.nn.Module, mesh: Mesh,
 def shard_row_parallel_linear(layer: torch.nn.Module, mesh: Mesh,
                               vllm_config: VllmConfig):
     assert isinstance(layer, RowParallelLinear)
-    jax_layer = JaxRowParallelLinear(layer, mesh)
+    jax_layer = JaxRowParallelLinear(
+        layer,
+        mesh,
+        enable_sequence_parallelism=vllm_config.compilation_config.pass_config.
+        enable_sequence_parallelism)
     return jax_layer
 
 
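For reference, the new enable_sequence_parallelism arguments are fed from vLLM's compilation pass config, i.e. vllm_config.compilation_config.pass_config.enable_sequence_parallelism. A hedged sketch of switching it on from the engine side follows; the dict-valued compilation_config, the create_engine_config() call, and the model name are assumptions that depend on the installed vLLM version.

```python
# Assumed API surface (verify against your vLLM version): EngineArgs taking a
# dict-valued compilation_config and create_engine_config() returning a VllmConfig.
from vllm import EngineArgs

engine_args = EngineArgs(
    model="Qwen/Qwen2-1.5B-Instruct",  # placeholder model name
    compilation_config={"pass_config": {"enable_sequence_parallelism": True}},
)
vllm_config = engine_args.create_engine_config()
# This is the attribute path the shard_* helpers above read:
print(vllm_config.compilation_config.pass_config.enable_sequence_parallelism)
```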
6 changes: 6 additions & 0 deletions tpu_commons/models/vllm/vllm_model_wrapper.py
@@ -138,6 +138,12 @@ def jit_step_func(self):
         @functools.partial(
             jax.jit,
             donate_argnums=(1, ),  # donate kv_cache
+            compiler_options={
+                "xla_tpu_all_gather_collective_matmul_mode":
+                "post_spmd_conservative",
+                "xla_tpu_reduce_scatter_collective_matmul_mode":
+                "post_spmd_conservative"
+            },
         )
         def step_fun(
                 params_and_buffers,  # this has been wrapped into a torchax TorchValue
7 changes: 4 additions & 3 deletions tpu_commons/runner/jax/tpu_jax_runner.py
@@ -57,6 +57,7 @@
 from tpu_commons.runner.jax.input_batch_jax import (CachedRequestState,
                                                     InputBatch)
 from tpu_commons.runner.jax.metadata import SpecDecodeMetadata
+from tpu_commons.utils import make_optimized_mesh
 
 logger = init_logger(__name__)
 
@@ -156,9 +157,9 @@ def _init_mesh(self) -> None:
         axis_names = ("data", "model")
         mesh_shape = (dp, tp)
 
-        self.mesh = jax.make_mesh(mesh_shape,
-                                  axis_names,
-                                  devices=self.devices)
+        self.mesh = make_optimized_mesh(mesh_shape,
+                                        axis_names,
+                                        devices=self.devices)
         logger.info(f"Init mesh | mesh={self.mesh}")
 
     def _init_inputs(self) -> None:
65 changes: 65 additions & 0 deletions tpu_commons/utils.py
@@ -1,16 +1,22 @@
 # SPDX-License-Identifier: Apache-2.0
 import os
 from collections import defaultdict
+from collections.abc import Sequence
 from typing import Any, List, Tuple
 
 import jax
+import numpy as np
 from jax._src import dtypes
+from jax._src import mesh as mesh_lib
+from jax._src import xla_bridge as xb
+from jax._src.lib import xla_client as xc
 from vllm import envs
 
 from tpu_commons.logger import init_logger
 
 GBYTES = 1024 * 1024 * 1024
 TPU_HEAD_SIZE_ALIGNMENT = 128
+TPU_SECOND_LAST_MINOR = 8
 
 _megacore = False
 logger = init_logger(__name__)
@@ -106,3 +112,62 @@ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
 def get_dtype_packing(dtype):
     bits = dtypes.bit_width(dtype)
     return 32 // bits
+
+
+def make_optimized_mesh(axis_shapes: Sequence[int],
+                        axis_names: Sequence[str],
+                        *,
+                        devices: Sequence[xc.Device] | None = None):
+    if devices is None:
+        devices = xb.devices()
+
+    def _is_1D(axis_shapes):
+        return sum(x > 1 for x in axis_shapes) == 1
+
+    if _is_1D(axis_shapes):
+        dev_kind = devices[0].device_kind
+        device_num = len(devices)
+        if dev_kind == "TPU v6 lite":
+            ordered_devices = None
+            # NOTE(chengjiyao):
+            # The coords of v6e-8 are
+            # (0,0,0)
+            # (1,0,0)
+            # (0,1,0)
+            # (1,1,0)
+            # (0,2,0)
+            # (1,2,0)
+            # (0,3,0)
+            # (1,3,0)
+            if device_num == 8:
+                ordered_devices = np.array([
+                    devices[0],
+                    devices[2],
Collaborator: I wonder why (1,0,0) maps to devices[2]. Also, does the order of the coords ((0,0,0), (1,0,0), (0,1,0), (1,1,0), (0,2,0), (1,2,0), (0,3,0), (1,3,0)) matter?

Collaborator Author: Yeah, given the order, it matters.

Collaborator: Why does (1,0,0) map to devices[2]?

+                    devices[4],
+                    devices[6],
+                    devices[7],
+                    devices[5],
+                    devices[3],
+                    devices[1],
+                ])
+            # NOTE(chengjiyao):
+            # The coords of v6e-4 are
+            # (0,0,0)
+            # (1,0,0)
+            # (0,1,0)
+            # (1,1,0)
+            elif device_num == 4:
+                ordered_devices = np.array([
+                    devices[0],
+                    devices[2],
+                    devices[3],
+                    devices[1],
+                ])
+            if ordered_devices is not None:
+                ordered_devices = np.array(ordered_devices)
+                ordered_devices = ordered_devices.reshape(axis_shapes)
+                mesh = mesh_lib.Mesh(ordered_devices, axis_names)
+                logger.info("Use customized mesh: %s", mesh)
+                return mesh
+
+    return jax.make_mesh(axis_shapes, axis_names, devices=devices)
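On the device-order question above: jax.devices() enumerates v6e chips in coordinate order ((0,0,0), (1,0,0), (0,1,0), ...), so the second entry of the hard-coded list, devices[2], is the chip at (0,1,0); the list walks down the x=0 column and back up the x=1 column, giving the ring [0, 2, 4, 6, 7, 5, 3, 1] for v6e-8 and [0, 2, 3, 1] for v6e-4, presumably so that neighbors in the 1-D mesh are physically adjacent. A small sketch (assuming TPU devices that expose a coords attribute) that derives the same order instead of hard-coding it:

```python
import jax

devices = jax.devices()
for i, d in enumerate(devices):
    # TPU devices expose physical coords; other backends may not.
    print(i, getattr(d, "coords", None))


def ring_order(devs):
    # Walk down the x == 0 column (increasing y), then back up the
    # x == 1 column (decreasing y), so 1-D mesh neighbors are adjacent chips.
    col0 = sorted((d for d in devs if d.coords[0] == 0),
                  key=lambda d: d.coords[1])
    col1 = sorted((d for d in devs if d.coords[0] == 1),
                  key=lambda d: d.coords[1],
                  reverse=True)
    return col0 + col1


if devices and hasattr(devices[0], "coords"):
    # On v6e-8 this prints [0, 2, 4, 6, 7, 5, 3, 1]; on v6e-4, [0, 2, 3, 1].
    print([devices.index(d) for d in ring_order(devices)])
```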