Support sequence parallelism and collective matmul #520
```diff
@@ -14,6 +14,7 @@
     forward_unqunatized, forward_w8a8_int8_col_parallel,
     reorder_concatenated_tensor_for_sharding,
     slice_sharded_tensor_for_concatenation)
+from tpu_commons.utils import TPU_SECOND_LAST_MINOR
 
 P = PartitionSpec
 
@@ -22,7 +23,7 @@ class JaxMergedColumnParallelLinearCore(torch.nn.Module):
     """ A common class to implement Column Parallel Linear layer whose weight are merged from a list of smaller weight tensors, e.g. vLLM's MergedColumnParallelLinear and QKVParallelLinear layer. """
 
     def __init__(self, vllm_col_par_linear: torch.nn.Module, mesh: Mesh,
-                 name: str, fuse_matmuls: bool):
+                 name: str, fuse_matmuls: bool, enable_sequence_parallelism):
         super().__init__()
 
         self.gather_output = vllm_col_par_linear.gather_output
 
@@ -33,6 +34,7 @@ def __init__(self, vllm_col_par_linear: torch.nn.Module, mesh: Mesh,
         self.name = name
         self.fuse_matmuls = fuse_matmuls
         self.has_bias = vllm_col_par_linear.bias is not None
+        self.enable_sequence_parallelism = enable_sequence_parallelism
         self.n_matmuls = len(self.output_sizes)
         assert vllm_col_par_linear.tp_size == 1, (
             "The model has to be loaded with TP== 1 in order to run in Jax SPMD."
 
@@ -230,7 +232,14 @@ def forward_split(self, input):
 
     def forward(self, input: torch.Tensor):
        with jax.named_scope(self.name):
+            if self.enable_sequence_parallelism:
```
Collaborator: A unit test is probably needed.

Author: I intended to add an e2e test, but could not find a good example.

Author: Thanks, @hfan! The unit test is added and was manually triggered in https://buildkite.com/tpu-commons/tpu-commons-ci/builds/2043
```diff
+                token_num = input.shape[0]
+                # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
```
Collaborator: Regarding "make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR": I guess the downside otherwise is wasted memory?

Author: There will be more communication on the final result; you can give it a try.
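To make the threshold concrete, here is a small worked example; the 8-way 'model' axis size is an assumption for illustration, not taken from the PR:

```python
# Illustrative numbers only: an assumed 8-way 'model' mesh axis.
TPU_SECOND_LAST_MINOR = 8
model_axis_size = 8

for token_num in (32, 64, 256):
    # Mirrors the guard in the diff: shard the token dimension only when each
    # shard keeps at least TPU_SECOND_LAST_MINOR (8) rows, the TPU's
    # second-to-last minor tile dimension.
    apply_sp = token_num // model_axis_size >= TPU_SECOND_LAST_MINOR
    print(token_num, apply_sp)  # 32 -> False, 64 -> True, 256 -> True
```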
```diff
+                if token_num // self.mesh.shape[
+                        'model'] >= TPU_SECOND_LAST_MINOR:
+                    input.shard_(NamedSharding(self.mesh, P('model', None)))
```
Collaborator: Very neat and convenient way of doing this.

Collaborator: It seems that SP is implemented by sharding the num_tokens dimension. Do you need to do an all-gather at the very end? I couldn't find it in your PR.

Author: Sharding propagation can handle this and make it correct.

Collaborator: Chatted offline: the all-gather is not needed because some later ops (e.g. selecting which tokens to compute logits for, which needs num_reqs tokens) hint the compiler to materialize the global view. At that point the compiler inserts an all-gather implicitly.
```diff
             if self.fuse_matmuls:
-                return self.forward_fused(input)
+                output, output_bias = self.forward_fused(input)
             else:
-                return self.forward_split(input)
+                output, output_bias = self.forward_split(input)
+        return output, output_bias
```
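To illustrate the sharding-propagation point in plain JAX: the sketch below is standalone and is not the PR's torchax code path; the MLP-style shapes, names, and mesh construction are made up. It shards the activation over the token dimension on the 'model' axis and leaves the collectives to the compiler.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D 'model' mesh over all local devices (a single device also works).
mesh = Mesh(np.array(jax.devices()), ('model',))


def column_parallel_forward(x, w):
    # Sequence parallelism on the activation: shard the token dimension.
    # This is the plain-JAX analogue of input.shard_(NamedSharding(...)) above.
    x = jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P('model', None)))
    # Column-parallel weight: output features sharded over 'model'.
    w = jax.lax.with_sharding_constraint(
        w, NamedSharding(mesh, P(None, 'model')))
    # No explicit collective is written here; XLA reconciles the two operand
    # shardings itself, e.g. by all-gathering the token-sharded activation
    # into the matmul, which is where an all-gather (collective) matmul can
    # be formed.
    return jnp.dot(x, w)


token_num, hidden, out_features = 128, 256, 512
x = jnp.ones((token_num, hidden))
w = jnp.ones((hidden, out_features))
y = jax.jit(column_parallel_forward)(x, w)
```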
```diff
@@ -15,20 +15,23 @@
 from tpu_commons.models.vllm.jax_linear_common import (
     forward_unqunatized, forward_w8a8_int8_row_parallel)
+from tpu_commons.utils import TPU_SECOND_LAST_MINOR
 
 P = PartitionSpec
 
 
 class JaxRowParallelLinear(torch.nn.Module):
 
-    def __init__(self, row_linear: torch.nn.Module, mesh: Mesh):
+    def __init__(self, row_linear: torch.nn.Module, mesh: Mesh,
+                 enable_sequence_parallelism: bool):
         super().__init__()
         assert isinstance(row_linear, RowParallelLinear)
 
         self.mesh = mesh
         self.reduce_results = row_linear.reduce_results
         self.skip_bias_add = row_linear.skip_bias_add
         self.return_bias = row_linear.return_bias
+        self.enable_sequence_parallelism = enable_sequence_parallelism
 
         self.w8q8_int8_quant = False
         if isinstance(row_linear.quant_method,
```
```diff
@@ -96,4 +99,10 @@ def forward(self, input: torch.Tensor):
         if not self.return_bias:
             return output
         output_bias = self.bias if self.skip_bias_add else None
+        if self.enable_sequence_parallelism:
+            token_num = input.shape[0]
+            # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
+            if token_num // self.mesh.shape[
+                    'model'] >= TPU_SECOND_LAST_MINOR:
+                output.shard_(NamedSharding(self.mesh, P('model', None)))
         return output, output_bias
```

Collaborator: Doesn't sequence parallelism usually use …?

Author: In this case, the mesh "model" axis alone is enough. Here, sequence parallelism is applied on layer_norm, and model parallelism is applied on the matmul, as described in this paper: https://arxiv.org/abs/2205.05198
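For readers less familiar with the scheme in that paper, here is a minimal JAX sketch of the idea. It uses plain `with_sharding_constraint` rather than the PR's torchax `shard_` call, and all shapes, names, and the mesh construction are illustrative assumptions. The row-parallel matmul output is constrained to be sharded over tokens, so the compiler can lower the usual all-reduce into a reduce-scatter and the following layer norm runs sequence-parallel.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D tensor-parallel mesh over all local devices (1 device also works).
mesh = Mesh(np.array(jax.devices()), ('model',))


def row_parallel_then_norm(x, w_o, gamma):
    # Row-parallel projection: the contraction (hidden) dim is sharded on
    # 'model', so each device computes a partial sum of the matmul.
    x = jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P(None, 'model')))
    w_o = jax.lax.with_sharding_constraint(
        w_o, NamedSharding(mesh, P('model', None)))
    y = jnp.dot(x, w_o)
    # Sequence parallelism: request the result sharded over tokens instead of
    # replicated; the compiler can then use a reduce-scatter here, and the
    # layer norm below runs on one token shard per device.
    y = jax.lax.with_sharding_constraint(
        y, NamedSharding(mesh, P('model', None)))
    mean = y.mean(axis=-1, keepdims=True)
    var = y.var(axis=-1, keepdims=True)
    return gamma * (y - mean) * jax.lax.rsqrt(var + 1e-6)


tokens, proj_in, hidden = 128, 512, 256
x = jnp.ones((tokens, proj_in))
w_o = jnp.ones((proj_in, hidden))
gamma = jnp.ones((hidden,))
out = jax.jit(row_parallel_then_norm)(x, w_o, gamma)
```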
```diff
@@ -1,16 +1,22 @@
 # SPDX-License-Identifier: Apache-2.0
 import os
 from collections import defaultdict
+from collections.abc import Sequence
 from typing import Any, List, Tuple
 
 import jax
+import numpy as np
 from jax._src import dtypes
+from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
+from jax._src.lib import xla_client as xc
 from vllm import envs
 
 from tpu_commons.logger import init_logger
 
 GBYTES = 1024 * 1024 * 1024
 TPU_HEAD_SIZE_ALIGNMENT = 128
+TPU_SECOND_LAST_MINOR = 8
 
 _megacore = False
+logger = init_logger(__name__)
 
@@ -106,3 +112,62 @@ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
 def get_dtype_packing(dtype):
     bits = dtypes.bit_width(dtype)
     return 32 // bits
+
+
+def make_optimized_mesh(axis_shapes: Sequence[int],
+                        axis_names: Sequence[str],
+                        *,
+                        devices: Sequence[xc.Device] | None = None):
+    if devices is None:
+        devices = xb.devices()
+
+    def _is_1D(axis_shapes):
+        return sum(x > 1 for x in axis_shapes) == 1
+
+    if _is_1D(axis_shapes):
+        dev_kind = devices[0].device_kind
+        device_num = len(devices)
+        if dev_kind == "TPU v6 lite":
+            ordered_devices = None
+            # NOTE(chengjiyao):
+            # The coords of v6e-8 are
+            # (0,0,0)
+            # (1,0,0)
+            # (0,1,0)
+            # (1,1,0)
+            # (0,2,0)
+            # (1,2,0)
+            # (0,3,0)
+            # (1,3,0)
+            if device_num == 8:
+                ordered_devices = np.array([
+                    devices[0],
+                    devices[2],
```
Collaborator: I wonder why … Also, does the order matter?

Author: Yeah, the order matters.

Collaborator: Why does …?
```diff
+                    devices[4],
+                    devices[6],
+                    devices[7],
+                    devices[5],
+                    devices[3],
+                    devices[1],
+                ])
+            # NOTE(chengjiyao):
+            # The coords of v6e-4 are
+            # (0,0,0)
+            # (1,0,0)
+            # (0,1,0)
+            # (1,1,0)
+            elif device_num == 4:
+                ordered_devices = np.array([
+                    devices[0],
+                    devices[2],
+                    devices[3],
+                    devices[1],
+                ])
+            if ordered_devices is not None:
+                ordered_devices = np.array(ordered_devices)
+                ordered_devices = ordered_devices.reshape(axis_shapes)
+                mesh = mesh_lib.Mesh(ordered_devices, axis_names)
+                logger.info("Use customized mesh: %s", mesh)
+                return mesh
+
+    return jax.make_mesh(axis_shapes, axis_names, devices=devices)
```
Collaborator: Is it using one TPU device? Do you need to test when there are multiple TPU devices?
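As an aside on the ordering question above: reading the coordinates in the NOTE, the chosen order 0, 2, 4, 6, 7, 5, 3, 1 appears to walk the 2x4 v6e-8 slice as a ring of physically adjacent chips, which is presumably friendlier to collectives on a 1-D axis than the default enumeration. Below is a hypothetical usage sketch; `make_optimized_mesh` and its module path are taken from this diff (the import in the linear layers suggests it lives in `tpu_commons.utils`), while the axis shape and sharding are illustrative.

```python
from jax.sharding import NamedSharding, PartitionSpec as P

from tpu_commons.utils import make_optimized_mesh

# 1-D tensor-parallel mesh over 8 chips. On a v6e-8 slice the devices are
# re-ordered into the ring above; on other topologies and shapes this falls
# back to jax.make_mesh.
mesh = make_optimized_mesh((8,), ('model',))

# Activations sharded over the token dimension, as in the sequence-parallel
# layers of this PR.
token_sharding = NamedSharding(mesh, P('model', None))
```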