From b5637304a8462ea1f8a50bff8c942afb28288baa Mon Sep 17 00:00:00 2001 From: Kyuyeun Kim Date: Sat, 16 Aug 2025 09:43:00 +0000 Subject: [PATCH] Refactor torchax layers to use vLLM APIs Signed-off-by: Kyuyeun Kim --- .../test_compressed_tensors_w8a8_int8.py | 404 ++++++++++++++ .../vllm/{ => layers}/test_jax_fused_moe.py | 0 .../models/vllm/layers/test_pallas_torchax.py | 236 ++++++++ tests/models/vllm/layers/test_unquantized.py | 392 ++++++++++++++ tests/models/vllm/layers/utils.py | 19 + tests/models/vllm/test_jax_attention.py | 153 ------ .../test_jax_merged_column_parallel_linear.py | 149 ----- .../vllm/test_jax_qkv_parallel_linear.py | 145 ----- .../vllm/test_jax_row_parallel_linear.py | 156 ------ tests/models/vllm/test_pallas_torchax.py | 508 ------------------ tests/models/vllm/utils.py | 53 -- .../attention/backends/pallas_torchax.py | 287 +++------- tpu_commons/models/vllm/jax_attention.py | 104 ---- tpu_commons/models/vllm/jax_linear_common.py | 105 ++-- .../vllm/jax_merged_column_parallel_linear.py | 20 - .../jax_merged_column_parallel_linear_core.py | 248 --------- .../models/vllm/jax_qkv_parallel_linear.py | 19 - .../models/vllm/jax_row_parallel_linear.py | 108 ---- .../models/vllm/quantization/__init__.py | 32 ++ .../models/vllm/quantization/common.py | 91 ++++ .../compressed_tensors/compressed_tensors.py | 109 ++++ .../schemes/compressed_tensors_w8a8_int8.py | 136 +++++ .../models/vllm/quantization/unquantized.py | 146 +++++ tpu_commons/models/vllm/sharding.py | 146 ++--- tpu_commons/models/vllm/vllm_model_wrapper.py | 19 +- .../models/vllm/vllm_model_wrapper_context.py | 13 +- tpu_commons/platforms/tpu_jax.py | 2 +- 27 files changed, 1782 insertions(+), 2018 deletions(-) create mode 100644 tests/models/vllm/layers/test_compressed_tensors_w8a8_int8.py rename tests/models/vllm/{ => layers}/test_jax_fused_moe.py (100%) create mode 100644 tests/models/vllm/layers/test_pallas_torchax.py create mode 100644 tests/models/vllm/layers/test_unquantized.py create mode 100644 tests/models/vllm/layers/utils.py delete mode 100644 tests/models/vllm/test_jax_attention.py delete mode 100644 tests/models/vllm/test_jax_merged_column_parallel_linear.py delete mode 100644 tests/models/vllm/test_jax_qkv_parallel_linear.py delete mode 100644 tests/models/vllm/test_jax_row_parallel_linear.py delete mode 100644 tests/models/vllm/test_pallas_torchax.py delete mode 100644 tests/models/vllm/utils.py delete mode 100644 tpu_commons/models/vllm/jax_attention.py delete mode 100644 tpu_commons/models/vllm/jax_merged_column_parallel_linear.py delete mode 100644 tpu_commons/models/vllm/jax_merged_column_parallel_linear_core.py delete mode 100644 tpu_commons/models/vllm/jax_qkv_parallel_linear.py delete mode 100644 tpu_commons/models/vllm/jax_row_parallel_linear.py create mode 100644 tpu_commons/models/vllm/quantization/__init__.py create mode 100644 tpu_commons/models/vllm/quantization/common.py create mode 100644 tpu_commons/models/vllm/quantization/compressed_tensors/compressed_tensors.py create mode 100644 tpu_commons/models/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py create mode 100644 tpu_commons/models/vllm/quantization/unquantized.py diff --git a/tests/models/vllm/layers/test_compressed_tensors_w8a8_int8.py b/tests/models/vllm/layers/test_compressed_tensors_w8a8_int8.py new file mode 100644 index 000000000..787bf413f --- /dev/null +++ b/tests/models/vllm/layers/test_compressed_tensors_w8a8_int8.py @@ -0,0 +1,404 @@ +import tempfile +from typing import Optional + +import 
jax +import pytest +import torch +import torchax +import utils as test_utils +from jax.sharding import NamedSharding, PartitionSpec +from torchax.interop import torch_view +from torchax.ops.mappings import j2t, t2j +from vllm.config import set_current_vllm_config +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.engine.arg_utils import EngineArgs +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \ + CompressedTensorsLinearMethod +from vllm.model_executor.model_loader import get_model as vllm_get_model + +from tpu_commons.models.vllm.quantization import get_tpu_quantization_config +from tpu_commons.models.vllm.quantization.compressed_tensors.compressed_tensors import \ + JaxCompressedTensorsConfig +from tpu_commons.models.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \ + JaxCompressedTensorsW8A8Int8 + +P = PartitionSpec + + +def ref_quantize_int8(x: torch.Tensor): + x_abs_max = torch.amax(torch.abs(x), dim=1, keepdim=True) + x_s = x_abs_max / 127 + x_q = torch.round(x / x_s).to(torch.int8) + return x_q, x_s.to(torch.float32) + + +def ref_w8a8_int8(x: torch.Tensor, w_q: torch.Tensor, w_s: torch.Tensor, + b: Optional[torch.Tensor]): + x_q, x_s = ref_quantize_int8(x) + out = torch.einsum('bd,fd->bf', x_q.to(torch.float32), + w_q.to(torch.float32)) + out = (out * x_s) * w_s.T + if b is not None: + out += b + return out.to(x.dtype) + + +@pytest.fixture(autouse=True) +def setup_environment(): + # This is a fake config used for init dist env. + # RowParallelLinear needs dist env to be initialized. 
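+    # The file:// rendezvous and the "gloo" backend below give a
+    # single-process, CPU-only distributed environment (no TPU required);
+    # ensure_model_parallel_initialized(1, 1) then registers the trivial
+    # TP=1 / PP=1 groups that the parallel linear layers expect.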
+ engine_args = EngineArgs( + model="RedHatAI/Qwen2.5-1.5B-quantized.w8a8", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + + vllm_config = engine_args.create_engine_config() + + with set_current_vllm_config(vllm_config): + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + 1, + 0, + local_rank=0, + distributed_init_method=f"file://{temp_file}", + backend="gloo") + ensure_model_parallel_initialized(1, 1) + + +@pytest.mark.parametrize("mesh", [ + test_utils.get_spmd_mesh(1), + test_utils.get_spmd_mesh(jax.local_device_count()) +]) +def test_quant_override(mesh): + + engine_args = EngineArgs( + model="RedHatAI/Qwen2.5-1.5B-quantized.w8a8", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.model_config.dtype = torch.bfloat16 + + quant_config = get_tpu_quantization_config(vllm_config, mesh) + assert isinstance(quant_config, JaxCompressedTensorsConfig) + assert quant_config.vllm_config == vllm_config + assert quant_config.mesh == mesh + + +@pytest.mark.parametrize("model", ["RedHatAI/Qwen2.5-1.5B-quantized.w8a8"]) +@pytest.mark.parametrize("mesh", [ + test_utils.get_spmd_mesh(1), + test_utils.get_spmd_mesh(jax.local_device_count()) +]) +def test_loading_model(model, mesh): + engine_args = EngineArgs( + model=model, + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.model_config.dtype = torch.bfloat16 + vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh) + vllm_config.device_config.device = "cpu" + + vllm_model = vllm_get_model(vllm_config=vllm_config) + layers = test_utils.find_all_layer_type(vllm_model, LinearBase) + for layer in layers: + assert isinstance(layer.quant_config, JaxCompressedTensorsConfig) + assert isinstance(layer.quant_method, CompressedTensorsLinearMethod) + assert isinstance(layer.scheme, JaxCompressedTensorsW8A8Int8) + + +@pytest.mark.parametrize("bias", [False, True]) +@pytest.mark.parametrize("mesh", [ + test_utils.get_spmd_mesh(1), + test_utils.get_spmd_mesh(jax.local_device_count()) +]) +@pytest.mark.parametrize("enable_sp", [False, True]) +def test_jax_row_parallel_linear(bias, mesh, enable_sp): + dtype = torch.bfloat16 + + engine_args = EngineArgs( + model="RedHatAI/Qwen2.5-1.5B-quantized.w8a8", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + + # Call tpu_commons code + vllm_config.model_config.dtype = dtype + quant_config = get_tpu_quantization_config(vllm_config, mesh) + with set_current_vllm_config(vllm_config): + jax_row_linear = RowParallelLinear( + input_size=4096, + output_size=8192, + bias=bias, + params_dtype=dtype, + return_bias=False, + quant_config=quant_config, + ) + + weight_data_float = torch.rand( + (jax_row_linear.output_size, jax_row_linear.input_size), + dtype=dtype) / 10 + weight_data, weight_scale_data = ref_quantize_int8(weight_data_float) + if bias: + bias_data = torch.rand_like(jax_row_linear.bias.data) + + jax_row_linear.weight.data = weight_data + jax_row_linear.weight_scale.data = weight_scale_data + if bias: + jax_row_linear.bias.data = bias_data + + input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 + input_tensor = input_tensor.to('cpu') + + jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False)) + 
jax_input_tensor.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + with torchax.default_env(): + assert isinstance(jax_row_linear.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(jax_row_linear.scheme, JaxCompressedTensorsW8A8Int8) + jax_row_linear.quant_method.process_weights_after_loading( + jax_row_linear) + jax_output = jax_row_linear(jax_input_tensor) + jax_output = j2t(jax_output.to(torch.float32)).to(dtype) + + # Call reference w8a8 int8 + output = ref_w8a8_int8( + input_tensor, + weight_data, + weight_scale_data, + bias_data if bias else None, + ) + + torch.testing.assert_close(output, jax_output) + + +@pytest.mark.parametrize("bias", [False, True]) +@pytest.mark.parametrize("mesh", [ + test_utils.get_spmd_mesh(1), + test_utils.get_spmd_mesh(jax.local_device_count()) +]) +@pytest.mark.parametrize("enable_sp", [False, True]) +def test_jax_column_parallel_linear(bias, mesh, enable_sp): + dtype = torch.bfloat16 + + engine_args = EngineArgs( + model="RedHatAI/Qwen2.5-1.5B-quantized.w8a8", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + + # Call tpu_commons code + vllm_config.model_config.dtype = torch.bfloat16 + quant_config = get_tpu_quantization_config(vllm_config, mesh) + with set_current_vllm_config(vllm_config): + jax_column_linear = ColumnParallelLinear( + input_size=4096, + output_size=8192, + bias=bias, + params_dtype=dtype, + return_bias=False, + quant_config=quant_config, + ) + + weight_data_float = torch.rand( + (jax_column_linear.output_size, jax_column_linear.input_size), + dtype=dtype) / 10 + weight_data, weight_scale_data = ref_quantize_int8(weight_data_float) + if bias: + bias_data = torch.rand_like(jax_column_linear.bias.data) + + jax_column_linear.weight.data = weight_data + jax_column_linear.weight_scale.data = weight_scale_data + if bias: + jax_column_linear.bias.data = bias_data + + input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 + input_tensor = input_tensor.to('cpu') + + jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False)) + jax_input_tensor.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + with torchax.default_env(): + assert isinstance(jax_column_linear.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(jax_column_linear.scheme, + JaxCompressedTensorsW8A8Int8) + jax_column_linear.quant_method.process_weights_after_loading( + jax_column_linear) + jax_output = jax_column_linear(jax_input_tensor) + jax_output = j2t(jax_output.to(torch.float32)).to(dtype) + + # Call reference w8a8 int8 + output = ref_w8a8_int8( + input_tensor, + weight_data, + weight_scale_data, + bias_data if bias else None, + ) + + torch.testing.assert_close(output, jax_output) + + +@pytest.mark.parametrize("bias", [False, True]) +@pytest.mark.parametrize("mesh", [ + test_utils.get_spmd_mesh(1), + test_utils.get_spmd_mesh(jax.local_device_count()) +]) +@pytest.mark.parametrize("enable_sp", [False, True]) +@pytest.mark.parametrize("fuse_matmuls", [False, True]) +def test_jax_qkv_parallel_linear(bias, mesh, enable_sp, fuse_matmuls): + dtype = torch.bfloat16 + + engine_args = EngineArgs( + model="RedHatAI/Qwen2.5-1.5B-quantized.w8a8", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + + # Call 
tpu_commons code + vllm_config.model_config.dtype = torch.bfloat16 + quant_config = get_tpu_quantization_config(vllm_config, mesh) + with set_current_vllm_config(vllm_config): + jax_qkv_linear = QKVParallelLinear( + hidden_size=4096, + head_size=128, + total_num_heads=32, + total_num_kv_heads=8, + bias=bias, + params_dtype=dtype, + return_bias=False, + quant_config=quant_config, + ) + jax_qkv_linear.quant_method.fuse_matmuls = fuse_matmuls + + weight_data_float = torch.rand( + (jax_qkv_linear.output_size, jax_qkv_linear.input_size), + dtype=dtype) / 10 + weight_data, weight_scale_data = ref_quantize_int8(weight_data_float) + if bias: + bias_data = torch.rand_like(jax_qkv_linear.bias.data) + + jax_qkv_linear.weight.data = weight_data + jax_qkv_linear.weight_scale.data = weight_scale_data + if bias: + jax_qkv_linear.bias.data = bias_data + + input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 + input_tensor = input_tensor.to('cpu') + + jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False)) + jax_input_tensor.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + with torchax.default_env(): + assert isinstance(jax_qkv_linear.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(jax_qkv_linear.scheme, JaxCompressedTensorsW8A8Int8) + jax_qkv_linear.quant_method.process_weights_after_loading( + jax_qkv_linear) + jax_output = jax_qkv_linear(jax_input_tensor) + jax_output = j2t(jax_output.to(torch.float32)).to(dtype) + + # Call reference w8a8 int8 + output = ref_w8a8_int8( + input_tensor, + weight_data, + weight_scale_data, + bias_data if bias else None, + ) + + torch.testing.assert_close(output, jax_output) + + +@pytest.mark.parametrize("bias", [False, True]) +@pytest.mark.parametrize("mesh", [ + test_utils.get_spmd_mesh(1), + test_utils.get_spmd_mesh(jax.local_device_count()) +]) +@pytest.mark.parametrize("fuse_matmuls", [False, True]) +@pytest.mark.parametrize("enable_sp", [False, True]) +def test_jax_merged_column_parallel_linear(bias, mesh, fuse_matmuls, + enable_sp): + dtype = torch.bfloat16 + + engine_args = EngineArgs( + model="RedHatAI/Qwen2.5-1.5B-quantized.w8a8", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + + # Call tpu_commons code + vllm_config.model_config.dtype = torch.bfloat16 + quant_config = get_tpu_quantization_config(vllm_config, mesh) + with set_current_vllm_config(vllm_config): + jax_merged_column_linear = MergedColumnParallelLinear( + input_size=4096, + output_sizes=[14336] * 2, + bias=bias, + params_dtype=dtype, + return_bias=False, + quant_config=quant_config, + ) + jax_merged_column_linear.quant_method.fuse_matmuls = fuse_matmuls + + weight_data_float = torch.rand((jax_merged_column_linear.output_size, + jax_merged_column_linear.input_size), + dtype=dtype) / 10 + weight_data, weight_scale_data = ref_quantize_int8(weight_data_float) + if bias: + bias_data = torch.rand_like(jax_merged_column_linear.bias.data) + + jax_merged_column_linear.weight.data = weight_data + jax_merged_column_linear.weight_scale.data = weight_scale_data + if bias: + jax_merged_column_linear.bias.data = bias_data + + input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 + input_tensor = input_tensor.to('cpu') + + jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False)) + jax_input_tensor.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + with torchax.default_env(): + assert 
isinstance(jax_merged_column_linear.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(jax_merged_column_linear.scheme, + JaxCompressedTensorsW8A8Int8) + jax_merged_column_linear.quant_method.process_weights_after_loading( + jax_merged_column_linear) + jax_output = jax_merged_column_linear(jax_input_tensor) + jax_output = j2t(jax_output.to(torch.float32)).to(dtype) + + # Call reference w8a8 int8 + output = ref_w8a8_int8( + input_tensor, + weight_data, + weight_scale_data, + bias_data if bias else None, + ) + + torch.testing.assert_close(output, jax_output) diff --git a/tests/models/vllm/test_jax_fused_moe.py b/tests/models/vllm/layers/test_jax_fused_moe.py similarity index 100% rename from tests/models/vllm/test_jax_fused_moe.py rename to tests/models/vllm/layers/test_jax_fused_moe.py diff --git a/tests/models/vllm/layers/test_pallas_torchax.py b/tests/models/vllm/layers/test_pallas_torchax.py new file mode 100644 index 000000000..273c0581e --- /dev/null +++ b/tests/models/vllm/layers/test_pallas_torchax.py @@ -0,0 +1,236 @@ +from unittest.mock import MagicMock + +import jax +import jax.numpy as jnp +import numpy as np +import pytest +import torch +import torchax +from jax.sharding import Mesh +from torchax.interop import torch_view +from vllm.attention.backends.abstract import AttentionType + +from tpu_commons.attention.backends.pallas_torchax import ( + PallasAttentionBackend, PallasAttentionBackendImpl) +from tpu_commons.models.jax.attention import get_kv_cache_shape_with_mesh +from tpu_commons.models.jax.attention_metadata import AttentionMetadata +from tpu_commons.models.vllm.vllm_model_wrapper_context import \ + set_vllm_model_wrapper_context + +# ---- Test Configuration & Constants ---- + +# Total number of tokens across all sequences in the batch +TOTAL_TOKENS = 10 +# Number of sequences in the batch +NUM_SEQS = 2 +# Padded maximum number of sequences +MAX_NUM_SEQS = 4 +# Number of attention heads (Query) +NUM_HEADS = 8 +# Number of attention heads (Key/Value) - for Grouped-Query Attention +NUM_KV_HEADS = 4 +# Dimension of each attention head +HEAD_DIM = 64 +# Padded head dimension +PADDED_HEAD_DIM = 64 +# Total number of blocks in the KV cache +NUM_BLOCKS = 32 +# Number of tokens per block +BLOCK_SIZE = 16 +# Maximum number of blocks a single sequence can occupy +MAX_BLOCKS_PER_SEQ = 8 + + +def create_inputs(mesh): + key = jax.random.key(0) + q = jax.random.uniform(key, (TOTAL_TOKENS, NUM_HEADS * HEAD_DIM), + dtype=jnp.bfloat16) + k = jax.random.uniform(key, (TOTAL_TOKENS, NUM_KV_HEADS * HEAD_DIM), + dtype=jnp.bfloat16) + v = jax.random.uniform(key, (TOTAL_TOKENS, NUM_KV_HEADS * HEAD_DIM), + dtype=jnp.bfloat16) + q = torch_view(q) + k = torch_view(k) + v = torch_view(v) + + kv_cache_shape = get_kv_cache_shape_with_mesh(mesh, NUM_BLOCKS, BLOCK_SIZE, + NUM_KV_HEADS, HEAD_DIM, + jnp.bfloat16) + kv_cache = jax.random.normal(key, kv_cache_shape, dtype=jnp.bfloat16) + + positions = jnp.ones((TOTAL_TOKENS, ), dtype=jnp.int32) + block_tables = jnp.zeros((MAX_NUM_SEQS * MAX_BLOCKS_PER_SEQ), + dtype=jnp.int32).reshape(-1) + seq_lens = jnp.array([5, 5, 0, 0], dtype=jnp.int32) + query_start_loc = jnp.array([0, 5, 10, 10, 10], dtype=jnp.int32) + request_distribution = jnp.array([0, 0, NUM_SEQS], dtype=jnp.int32) + + metadata = AttentionMetadata( + input_positions=positions, + block_tables=block_tables, + seq_lens=seq_lens, + query_start_loc=query_start_loc, + request_distribution=request_distribution, + ) + + return q, k, v, kv_cache, metadata + + +@pytest.fixture +def mesh(): 
+ """Provides a mock 1D JAX mesh for testing.""" + # Create a mesh with available devices, useful for running on CPU/GPU/TPU + # For this test, it will likely be a single CPU device. + devices = np.array(jax.local_devices()) + if not devices.any(): + # Add a mock device if no devices are present (e.g., in a CI environment) + devices = np.array([jax.devices("cpu")[0]]) + return Mesh(devices.reshape((-1, 1)), ("data", "model")) + + +class TestPallasAttentionBackend: + + def test_get_name(self): + assert PallasAttentionBackend.get_name() == "PALLAS_VLLM_V1" + + def test_get_impl_cls(self): + assert PallasAttentionBackend.get_impl_cls( + ) == PallasAttentionBackendImpl + + +class TestPallasAttentionBackendImpl: + + def test_init_valid_params(self): + impl = PallasAttentionBackendImpl( + num_heads=32, + head_size=128, + scale=0.088, + num_kv_heads=8, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + attn_type=AttentionType.DECODER, + ) + + assert impl.num_heads == 32 + assert impl.head_size == 128 + assert impl.scale == 0.088 + assert impl.num_kv_heads == 8 + assert impl.num_queries_per_kv == 4 + assert impl.sliding_window is None + + def test_init_with_alibi_slopes_raises_error(self): + with pytest.raises(NotImplementedError, + match="Alibi slopes is not supported"): + PallasAttentionBackendImpl( + num_heads=32, + head_size=128, + scale=0.088, + num_kv_heads=8, + alibi_slopes=[1.0, 2.0], + sliding_window=None, + kv_cache_dtype="auto", + attn_type=AttentionType.DECODER, + ) + + def test_init_with_fp8_kv_cache_raises_error(self): + with pytest.raises(NotImplementedError, + match="FP8 KV cache dtype is not supported"): + PallasAttentionBackendImpl( + num_heads=32, + head_size=128, + scale=0.088, + num_kv_heads=8, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="fp8", + attn_type=AttentionType.DECODER, + ) + + def test_init_with_encoder_attention_raises_error(self): + with pytest.raises(NotImplementedError, + match="Encoder self-attention"): + PallasAttentionBackendImpl( + num_heads=32, + head_size=128, + scale=0.088, + num_kv_heads=8, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + attn_type=AttentionType.ENCODER, + ) + + def test_forward(self, mesh): + impl = PallasAttentionBackendImpl( + num_heads=NUM_HEADS, + head_size=HEAD_DIM, + scale=0.088, + num_kv_heads=NUM_KV_HEADS, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + attn_type=AttentionType.DECODER, + ) + + layer = MagicMock() + layer.layer_name = "0" + + query, key, value, kv_cache, metadata = create_inputs(mesh) + + with torchax.default_env(), set_vllm_model_wrapper_context( + kv_caches=[kv_cache], mesh=mesh): + impl.forward(layer, query, key, value, torch.tensor([]), metadata) + + def test_forward_with_vllm_kv_cache_raises_error(self, mesh): + impl = PallasAttentionBackendImpl( + num_heads=NUM_HEADS, + head_size=HEAD_DIM, + scale=0.088, + num_kv_heads=NUM_KV_HEADS, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + attn_type=AttentionType.DECODER, + ) + + layer = MagicMock() + layer.layer_name = "0" + + query, key, value, kv_cache, metadata = create_inputs(mesh) + + with torchax.default_env(), set_vllm_model_wrapper_context( + kv_caches=[kv_cache], + mesh=mesh), pytest.raises(RuntimeError, + match="should be empty but has"): + impl.forward(layer, query, key, value, torch.tensor([1]), metadata) + + def test_forward_with_output_scale_raises_error(self, mesh): + impl = PallasAttentionBackendImpl( + num_heads=NUM_HEADS, + head_size=HEAD_DIM, + 
scale=0.088, + num_kv_heads=NUM_KV_HEADS, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + attn_type=AttentionType.DECODER, + ) + + layer = MagicMock() + layer.layer_name = "0" + + query, key, value, kv_cache, metadata = create_inputs(mesh) + output_scale = torch.tensor([1.0]) + + with torchax.default_env(), set_vllm_model_wrapper_context( + kv_caches=[kv_cache], + mesh=mesh), pytest.raises(NotImplementedError, + match="fused output quantization"): + impl.forward(layer, + query, + key, + value, + torch.tensor([]), + metadata, + output_scale=output_scale) diff --git a/tests/models/vllm/layers/test_unquantized.py b/tests/models/vllm/layers/test_unquantized.py new file mode 100644 index 000000000..592cba169 --- /dev/null +++ b/tests/models/vllm/layers/test_unquantized.py @@ -0,0 +1,392 @@ +import tempfile + +import jax +import pytest +import torch +import torchax +import utils as test_utils +from jax.sharding import NamedSharding, PartitionSpec +from torchax.interop import torch_view +from torchax.ops.mappings import j2t, t2j +from vllm.config import set_current_vllm_config +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.engine.arg_utils import EngineArgs +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.model_loader import get_model as vllm_get_model + +from tpu_commons.models.vllm.quantization import get_tpu_quantization_config +from tpu_commons.models.vllm.quantization.unquantized import ( + JaxUnquantizedConfig, JaxUnquantizedLinearMethod) + +P = PartitionSpec + + +@pytest.fixture(autouse=True) +def setup_environment(): + # This is a fake config used for init dist env. + # RowParallelLinear needs dist env to be initialized. 
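+    # A file:// rendezvous with the "gloo" backend keeps this a
+    # single-process, CPU-only setup; ensure_model_parallel_initialized(1, 1)
+    # registers the TP=1 / PP=1 groups required by the parallel linear layers.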
+ engine_args = EngineArgs( + model="Qwen/Qwen2-1.5B-Instruct", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + + vllm_config = engine_args.create_engine_config() + + with set_current_vllm_config(vllm_config): + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + 1, + 0, + local_rank=0, + distributed_init_method=f"file://{temp_file}", + backend="gloo") + ensure_model_parallel_initialized(1, 1) + + +@pytest.mark.parametrize("mesh", [ + test_utils.get_spmd_mesh(1), + test_utils.get_spmd_mesh(jax.local_device_count()) +]) +def test_quant_override(mesh): + + engine_args = EngineArgs( + model="Qwen/Qwen2-1.5B-Instruct", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.model_config.dtype = torch.bfloat16 + + quant_config = get_tpu_quantization_config(vllm_config, mesh) + assert isinstance(quant_config, JaxUnquantizedConfig) + assert quant_config.vllm_config == vllm_config + assert quant_config.mesh == mesh + + +@pytest.mark.parametrize("model", ["Qwen/Qwen2-1.5B-Instruct"]) +@pytest.mark.parametrize("mesh", [ + test_utils.get_spmd_mesh(1), + test_utils.get_spmd_mesh(jax.local_device_count()) +]) +def test_loading_model(model, mesh): + engine_args = EngineArgs( + model=model, + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.model_config.dtype = torch.bfloat16 + vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh) + vllm_config.device_config.device = "cpu" + + vllm_model = vllm_get_model(vllm_config=vllm_config) + layers = test_utils.find_all_layer_type(vllm_model, LinearBase) + for layer in layers: + assert isinstance(layer.quant_config, JaxUnquantizedConfig) + assert isinstance(layer.quant_method, JaxUnquantizedLinearMethod) + + +@pytest.mark.parametrize("bias", [False, True]) +@pytest.mark.parametrize("mesh", [ + test_utils.get_spmd_mesh(1), + test_utils.get_spmd_mesh(jax.local_device_count()) +]) +@pytest.mark.parametrize("enable_sp", [False, True]) +def test_jax_row_parallel_linear(bias, mesh, enable_sp): + dtype = torch.bfloat16 + + engine_args = EngineArgs( + model="Qwen/Qwen2-1.5B-Instruct", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + + input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 + input_tensor = input_tensor.to('cpu') + + with set_current_vllm_config(vllm_config): + row_linear = RowParallelLinear( + input_size=4096, + output_size=8192, + bias=bias, + params_dtype=dtype, + return_bias=False, + ) + + weight_data = torch.rand_like(row_linear.weight.data) / 10 + if bias: + bias_data = torch.rand_like(row_linear.bias.data) + + row_linear.weight.data = weight_data + if bias: + row_linear.bias.data = bias_data + row_linear = row_linear.to('cpu') + row_linear.quant_method.process_weights_after_loading(row_linear) + output = row_linear(input_tensor).to(dtype) + + vllm_config.model_config.dtype = dtype + quant_config = get_tpu_quantization_config(vllm_config, mesh) + with set_current_vllm_config(vllm_config): + jax_row_linear = RowParallelLinear( + input_size=4096, + output_size=8192, + bias=bias, + params_dtype=dtype, + return_bias=False, + quant_config=quant_config, + ) + + jax_row_linear.weight.data = weight_data + if bias: + jax_row_linear.bias.data = bias_data + + jax_input_tensor = 
torch_view(t2j(input_tensor, use_dlpack=False)) + jax_input_tensor.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + with torchax.default_env(): + assert isinstance(jax_row_linear.quant_method, + JaxUnquantizedLinearMethod) + jax_row_linear.quant_method.process_weights_after_loading( + jax_row_linear) + jax_output = jax_row_linear(jax_input_tensor) + # j2t() doens't support bfloat16, so we cast it into float32 as an intermedate step. + jax_output = j2t(jax_output.to(torch.float32)).to(dtype) + + torch.testing.assert_close(output, jax_output) + + +@pytest.mark.parametrize("bias", [False, True]) +@pytest.mark.parametrize("mesh", [ + test_utils.get_spmd_mesh(1), + test_utils.get_spmd_mesh(jax.local_device_count()) +]) +@pytest.mark.parametrize("enable_sp", [False, True]) +def test_jax_column_parallel_linear(bias, mesh, enable_sp): + dtype = torch.bfloat16 + + engine_args = EngineArgs( + model="Qwen/Qwen2-1.5B-Instruct", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + + input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 + input_tensor = input_tensor.to('cpu') + + with set_current_vllm_config(vllm_config): + column_linear = ColumnParallelLinear( + input_size=4096, + output_size=8192, + bias=bias, + params_dtype=dtype, + return_bias=False, + ) + + weight_data = torch.rand_like(column_linear.weight.data) / 10 + if bias: + bias_data = torch.rand_like(column_linear.bias.data) + + column_linear.weight.data = weight_data + if bias: + column_linear.bias.data = bias_data + column_linear = column_linear.to('cpu') + column_linear.quant_method.process_weights_after_loading(column_linear) + output = column_linear(input_tensor).to(dtype) + + vllm_config.model_config.dtype = dtype + quant_config = get_tpu_quantization_config(vllm_config, mesh) + with set_current_vllm_config(vllm_config): + jax_column_linear = ColumnParallelLinear( + input_size=4096, + output_size=8192, + bias=bias, + params_dtype=dtype, + return_bias=False, + quant_config=quant_config, + ) + + jax_column_linear.weight.data = weight_data + if bias: + jax_column_linear.bias.data = bias_data + + jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False)) + jax_input_tensor.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + with torchax.default_env(): + assert isinstance(jax_column_linear.quant_method, + JaxUnquantizedLinearMethod) + jax_column_linear.quant_method.process_weights_after_loading( + jax_column_linear) + jax_output = jax_column_linear(jax_input_tensor) + jax_output = j2t(jax_output.to(torch.float32)).to(dtype) + + torch.testing.assert_close(output, jax_output) + + +@pytest.mark.parametrize("bias", [False, True]) +@pytest.mark.parametrize("mesh", [ + test_utils.get_spmd_mesh(1), + test_utils.get_spmd_mesh(jax.local_device_count()) +]) +@pytest.mark.parametrize("enable_sp", [False, True]) +@pytest.mark.parametrize("fuse_matmuls", [False, True]) +def test_jax_qkv_parallel_linear(bias, mesh, enable_sp, fuse_matmuls): + dtype = torch.bfloat16 + + engine_args = EngineArgs( + model="Qwen/Qwen2-1.5B-Instruct", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + + input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 + input_tensor = input_tensor.to('cpu') + + with 
set_current_vllm_config(vllm_config): + qkv_linear = QKVParallelLinear( + hidden_size=4096, + head_size=128, + total_num_heads=32, + total_num_kv_heads=8, + bias=bias, + params_dtype=dtype, + return_bias=False, + ) + + weight_data = torch.rand_like(qkv_linear.weight.data) / 10 + if bias: + bias_data = torch.rand_like(qkv_linear.bias.data) + + qkv_linear.weight.data = weight_data + if bias: + qkv_linear.bias.data = bias_data + qkv_linear = qkv_linear.to('cpu') + qkv_linear.quant_method.process_weights_after_loading(qkv_linear) + output = qkv_linear(input_tensor).to(dtype) + + vllm_config.model_config.dtype = dtype + quant_config = get_tpu_quantization_config(vllm_config, mesh) + with set_current_vllm_config(vllm_config): + jax_qkv_linear = QKVParallelLinear( + hidden_size=4096, + head_size=128, + total_num_heads=32, + total_num_kv_heads=8, + bias=bias, + params_dtype=dtype, + return_bias=False, + quant_config=quant_config, + ) + jax_qkv_linear.quant_method.fuse_matmuls = fuse_matmuls + + jax_qkv_linear.weight.data = weight_data + if bias: + jax_qkv_linear.bias.data = bias_data + + jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False)) + jax_input_tensor.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + with torchax.default_env(): + assert isinstance(jax_qkv_linear.quant_method, + JaxUnquantizedLinearMethod) + jax_qkv_linear.quant_method.process_weights_after_loading( + jax_qkv_linear) + jax_output = jax_qkv_linear(jax_input_tensor) + jax_output = j2t(jax_output.to(torch.float32)).to(dtype) + + torch.testing.assert_close(output, jax_output) + + +@pytest.mark.parametrize("bias", [False, True]) +@pytest.mark.parametrize("mesh", [ + test_utils.get_spmd_mesh(1), + test_utils.get_spmd_mesh(jax.local_device_count()) +]) +@pytest.mark.parametrize("fuse_matmuls", [False, True]) +@pytest.mark.parametrize("enable_sp", [False, True]) +def test_jax_merged_column_parallel_linear(bias, mesh, fuse_matmuls, + enable_sp): + dtype = torch.bfloat16 + + engine_args = EngineArgs( + model="Qwen/Qwen2-1.5B-Instruct", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + + input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 + input_tensor = input_tensor.to('cpu') + + # Call vLLM code + with set_current_vllm_config(vllm_config): + merged_column_linear = MergedColumnParallelLinear( + input_size=4096, + output_sizes=[14336] * 2, + bias=bias, + params_dtype=dtype, + return_bias=False, + ) + + weight_data = torch.rand_like(merged_column_linear.weight.data) / 10 + if bias: + bias_data = torch.rand_like(merged_column_linear.bias.data) + + merged_column_linear.weight.data = weight_data + if bias: + merged_column_linear.bias.data = bias_data + merged_column_linear = merged_column_linear.to('cpu') + merged_column_linear.quant_method.process_weights_after_loading( + merged_column_linear) + output = merged_column_linear(input_tensor).to(dtype) + + # Call tpu_commons code + vllm_config.model_config.dtype = dtype + quant_config = get_tpu_quantization_config(vllm_config, mesh) + with set_current_vllm_config(vllm_config): + jax_merged_column_linear = MergedColumnParallelLinear( + input_size=4096, + output_sizes=[14336] * 2, + bias=bias, + params_dtype=dtype, + return_bias=False, + quant_config=quant_config, + ) + jax_merged_column_linear.quant_method.fuse_matmuls = fuse_matmuls + + jax_merged_column_linear.weight.data = weight_data + if bias: + 
jax_merged_column_linear.bias.data = bias_data + + jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False)) + jax_input_tensor.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + with torchax.default_env(): + assert isinstance(jax_merged_column_linear.quant_method, + JaxUnquantizedLinearMethod) + jax_merged_column_linear.quant_method.process_weights_after_loading( + jax_merged_column_linear) + jax_output = jax_merged_column_linear(jax_input_tensor) + jax_output = j2t(jax_output.to(torch.float32)).to(dtype) + + torch.testing.assert_close(output, jax_output) diff --git a/tests/models/vllm/layers/utils.py b/tests/models/vllm/layers/utils.py new file mode 100644 index 000000000..67c211495 --- /dev/null +++ b/tests/models/vllm/layers/utils.py @@ -0,0 +1,19 @@ +import jax +import torch + + +def get_spmd_mesh(num_devices: int = 1): + axis_names = ("data", "model") + devices = sorted(jax.devices(), key=lambda d: d.id)[0:num_devices] + mesh_shape = (1, len(devices)) + return jax.make_mesh(mesh_shape, axis_names, devices=devices) + + +def find_all_layer_type(module: torch.nn.Module, layer_type: torch.nn.Module): + ret = [] + for name, child in module.named_children(): + if isinstance(child, layer_type): + ret.append(child) + else: + ret.extend(find_all_layer_type(child, layer_type)) + return ret diff --git a/tests/models/vllm/test_jax_attention.py b/tests/models/vllm/test_jax_attention.py deleted file mode 100644 index c74987780..000000000 --- a/tests/models/vllm/test_jax_attention.py +++ /dev/null @@ -1,153 +0,0 @@ -import jax -import jax.numpy as jnp -import pytest -import torch -import torchax -import utils as test_utils -from jax.sharding import NamedSharding, PartitionSpec -from torchax.interop import torch_view -from torchax.ops.mappings import j2t, t2j, t2j_dtype -from vllm.attention import Attention as VllmAttention -from vllm.config import set_current_vllm_config -from vllm.engine.arg_utils import EngineArgs - -from tpu_commons.kernels.ragged_paged_attention.v3.kernel import \ - ref_ragged_paged_attention -from tpu_commons.models.jax.attention import get_kv_cache_shape_with_mesh -from tpu_commons.models.jax.attention_metadata import AttentionMetadata -from tpu_commons.models.vllm.jax_attention import JaxAttention -from tpu_commons.models.vllm.vllm_model_wrapper_context import ( - get_vllm_model_wrapper_context, set_vllm_model_wrapper_context) - -P = PartitionSpec - - -def generate_attention_metadata(num_tokens, mesh) -> AttentionMetadata: - input_positions = None # not used in test, doesn't matter - block_tables = jnp.array( - [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], - [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], - dtype=jnp.int32) - seq_lens = jnp.array([num_tokens, 0, 0, 0, 0, 0, 0, 0], dtype=jnp.int32) - query_start_loc = jnp.array([0, num_tokens, 1, 1, 1, 1, 1, 1, 1], - dtype=jnp.int32) - request_distribution = jnp.array([0, 0, 1], dtype=jnp.int32) - - input_positions = jax.device_put(input_positions, - device=NamedSharding( - mesh, PartitionSpec(None))) - block_tables = jax.device_put(block_tables.reshape(-1), - device=NamedSharding(mesh, - PartitionSpec(None))) - seq_lens = jax.device_put(seq_lens, - device=NamedSharding(mesh, PartitionSpec(None))) - query_start_loc = jax.device_put(query_start_loc, - device=NamedSharding( - mesh, PartitionSpec(None))) - - attention_metadata = AttentionMetadata( - input_positions=input_positions, - block_tables=block_tables, - seq_lens=seq_lens, - query_start_loc=query_start_loc, - 
request_distribution=request_distribution, - ) - return attention_metadata - - -def generate_kv_caches(num_kv_heads, head_size, mesh, dtype): - cache_shape = get_kv_cache_shape_with_mesh(mesh, 1024, 16, num_kv_heads, - head_size, t2j_dtype(dtype)) - sharding = NamedSharding(mesh, PartitionSpec()) - - def _allocate(): - return jnp.empty( - shape=cache_shape, - dtype=t2j_dtype(dtype), - ) - - sharded_allocate = jax.jit(_allocate, out_shardings=sharding) - return [sharded_allocate()] - - -@pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh()]) -@pytest.mark.parametrize("num_heads", [8, 32]) -@pytest.mark.parametrize("head_size", [96, 128]) -@pytest.mark.parametrize("num_kv_heads", [8]) -@pytest.mark.parametrize("num_tokens", [15, 63]) -def test_jax_attention(mesh, num_heads, head_size, num_kv_heads, num_tokens): - dtype = torch.bfloat16 - - engine_args = EngineArgs( - model="Qwen/Qwen2-1.5B-Instruct", - max_model_len=64, - max_num_batched_tokens=64, - max_num_seqs=4, - ) - vllm_config = engine_args.create_engine_config() - vllm_config.model_config.dtype = dtype - with set_current_vllm_config(vllm_config): - attention = VllmAttention( - num_heads=num_heads, - head_size=head_size, - scale=float('nan'), # doesn't matter - num_kv_heads=num_kv_heads, - prefix="test_jax_attention.layer.0", - ) - - scale = float(1.0 / (head_size**0.5)) - qkv = torch.empty(num_tokens, - num_heads + 2 * num_kv_heads, - head_size, - dtype=dtype) - qkv.uniform_(-scale, scale) - q, k, v = qkv.split([num_heads, num_kv_heads, num_kv_heads], dim=1) - - # reshape q,k,v into vLLM convention - vllm_q = q.view(num_tokens, num_heads * head_size) - vllm_k = k.view(num_tokens, num_kv_heads * head_size) - vllm_v = v.view(num_tokens, num_kv_heads * head_size) - - # Set jax default device to workaround a layout bug in JAX 0.7.0 and earlier - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): - jax_attention = JaxAttention(attention, mesh=mesh) - vllm_q = torch_view(t2j(vllm_q)) - vllm_q.apply_jax_(jax.device_put, NamedSharding(mesh, P())) - vllm_k = torch_view(t2j(vllm_k)) - vllm_k.apply_jax_(jax.device_put, NamedSharding(mesh, P())) - vllm_v = torch_view(t2j(vllm_v)) - vllm_v.apply_jax_(jax.device_put, NamedSharding(mesh, P())) - q = t2j(q) - q = jax.device_put(q, NamedSharding(mesh, P())) - - md = generate_attention_metadata(num_tokens, mesh) - kv_caches = generate_kv_caches(num_kv_heads, head_size, mesh, dtype) - - with torchax.default_env(), set_vllm_model_wrapper_context( - kv_caches=kv_caches, - attention_metadata=md, - ): - jax_output = jax_attention(vllm_q, vllm_k, vllm_v) - - # reshape from vLLM convention - jax_output = jax_output.view(num_tokens, num_heads, head_size) - # j2t() doens't support bfloat16, so we cast it into float32 as an intermedate step. 
- jax_output = j2t(jax_output.to(torch.float32)).to(dtype) - - # the above jax_attetion call also updates the kv cache - vllm_model_wrapper_context = get_vllm_model_wrapper_context() - kv_cache = vllm_model_wrapper_context.kv_caches[0] - - ref_output, _ = ref_ragged_paged_attention( - q, - jax.device_put(t2j(k), NamedSharding(mesh, P())), - jax.device_put(t2j(v), NamedSharding(mesh, P())), - kv_cache, - md.seq_lens, - md.block_tables, - md.query_start_loc, - md.request_distribution, - sm_scale=scale) - ref_output = j2t(ref_output.astype(jnp.float32)).to(dtype) - - torch.testing.assert_close(ref_output, jax_output, atol=1e-2, rtol=1e-5) diff --git a/tests/models/vllm/test_jax_merged_column_parallel_linear.py b/tests/models/vllm/test_jax_merged_column_parallel_linear.py deleted file mode 100644 index 813ee3d66..000000000 --- a/tests/models/vllm/test_jax_merged_column_parallel_linear.py +++ /dev/null @@ -1,149 +0,0 @@ -import tempfile -from unittest.mock import patch - -import jax -import pytest -import torch -import torchax -import utils as test_utils -from jax.sharding import NamedSharding, PartitionSpec -from torchax.interop import torch_view -from torchax.ops.mappings import j2t, t2j -from vllm.config import set_current_vllm_config -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.engine.arg_utils import EngineArgs -from vllm.model_executor.layers.linear import MergedColumnParallelLinear -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \ - CompressedTensorsLinearMethod -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import \ - CompressedTensorsW8A8Int8 - -from tpu_commons.models.vllm.jax_merged_column_parallel_linear import \ - JaxMergedColumnParallelLinear - -P = PartitionSpec - -_vllm_config = EngineArgs( - model="Qwen/Qwen2-1.5B-Instruct", - max_model_len=64, - max_num_batched_tokens=64, - max_num_seqs=4, -).create_engine_config() - - -@pytest.fixture(autouse=True) -def setup_environment(): - # This is a fake config used for init dist env. - # QKVParallelLinear needs dist env to be initialized. 
- with set_current_vllm_config(_vllm_config): - temp_file = tempfile.mkstemp()[1] - init_distributed_environment( - 1, - 0, - local_rank=0, - distributed_init_method=f"file://{temp_file}", - backend="gloo") - ensure_model_parallel_initialized(1, 1) - - -@pytest.mark.parametrize("bias", [False, True]) -@pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh()]) -@pytest.mark.parametrize("fuse_matmuls", [False, True]) -@pytest.mark.parametrize("enable_sp", [False, True]) -def test_jax_merged_column_parallel_linear(bias, mesh, fuse_matmuls, - enable_sp): - dtype = torch.bfloat16 - with set_current_vllm_config(_vllm_config): - merged_column_linear = MergedColumnParallelLinear( - input_size=4096, - output_sizes=[14336] * 2, - bias=bias, - params_dtype=dtype, - return_bias=False, - ) - merged_column_linear.weight.data = torch.rand_like( - merged_column_linear.weight.data) / 10 - if bias: - merged_column_linear.bias.data = torch.rand_like( - merged_column_linear.bias.data) - merged_column_linear = merged_column_linear.to('cpu') - merged_column_linear.quant_method.process_weights_after_loading( - merged_column_linear) - - input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 - input_tensor = input_tensor.to('cpu') - output = merged_column_linear(input_tensor).to(dtype) - - # Set jax default device to workaround a layout bug in JAX 0.7.0 and earlier - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): - jax_merged_column_linear = JaxMergedColumnParallelLinear( - merged_column_linear, - mesh, - fuse_matmuls, - enable_sequence_parallelism=enable_sp) - jax_input_tensor = torch_view(t2j(input_tensor)) - jax_input_tensor.apply_jax_(jax.device_put, - NamedSharding(mesh, P(None, None))) - with torchax.default_env(): - jax_output = jax_merged_column_linear(jax_input_tensor) - # j2t() doens't support bfloat16, so we cast it into float32 as an intermedate step. - jax_output = j2t(jax_output.to(torch.float32)).to(dtype) - - torch.testing.assert_close(output, jax_output) - - -@pytest.mark.parametrize("bias", [False, True]) -@pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh()]) -@pytest.mark.parametrize("fuse_matmuls", [False, True]) -@pytest.mark.parametrize("enable_sp", [False, True]) -def test_jax_merged_column_parallel_linear_w8a8_int8(bias, mesh, fuse_matmuls, - enable_sp): - dtype = torch.bfloat16 - with set_current_vllm_config(_vllm_config): - merged_column_linear = MergedColumnParallelLinear( - input_size=4096, - output_sizes=[14336] * 2, - bias=bias, - params_dtype=dtype, - return_bias=False, - quant_config=test_utils.gen_vllm_w8a8_int8_config(), - ) - - # Assert we're testing the right code path when quant config is set. 
- assert isinstance(merged_column_linear.quant_method, - CompressedTensorsLinearMethod) - assert isinstance(merged_column_linear.scheme, CompressedTensorsW8A8Int8) - - merged_column_linear.weight.data = torch.randint_like( - merged_column_linear.weight.data, low=-128, high=128) - merged_column_linear.weight_scale.data = torch.rand_like( - merged_column_linear.weight_scale.data) / 10 - if bias: - merged_column_linear.bias.data = torch.rand_like( - merged_column_linear.bias.data) - merged_column_linear = merged_column_linear.to('cpu') - merged_column_linear.quant_method.process_weights_after_loading( - merged_column_linear) - - input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 - input_tensor = input_tensor.to('cpu') - # Overwrite the torch_xla kernel with a reference implementation, as it's difficult to call torch_xla in tpu_commons and we want to run the ref result on CPU. - with patch( - "vllm.model_executor.layers.quantization.kernels.scaled_mm.xla.XLAScaledMMLinearKernel.apply_weights", - new=test_utils.quantized_matmul_ref): - output = merged_column_linear(input_tensor).to(dtype) - - # Set jax default device to workaround a layout bug in JAX 0.7.0 and earlier - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): - jax_merged_column_linear = JaxMergedColumnParallelLinear( - merged_column_linear, mesh, fuse_matmuls, enable_sp) - jax_input_tensor = torch_view(t2j(input_tensor)) - jax_input_tensor.apply_jax_(jax.device_put, - NamedSharding(mesh, P(None, None))) - with torchax.default_env(): - jax_output = jax_merged_column_linear(jax_input_tensor) - # j2t() doens't support bfloat16, so we cast it into float32 as an intermedate step. - jax_output = j2t(jax_output.to(torch.float32)).to(dtype) - - torch.testing.assert_close(output, jax_output, atol=5, rtol=0.1) diff --git a/tests/models/vllm/test_jax_qkv_parallel_linear.py b/tests/models/vllm/test_jax_qkv_parallel_linear.py deleted file mode 100644 index e7019ecb7..000000000 --- a/tests/models/vllm/test_jax_qkv_parallel_linear.py +++ /dev/null @@ -1,145 +0,0 @@ -import tempfile -from unittest.mock import patch - -import jax -import pytest -import torch -import torchax -import utils as test_utils -from jax.sharding import NamedSharding, PartitionSpec -from torchax.interop import torch_view -from torchax.ops.mappings import j2t, t2j -from vllm.config import set_current_vllm_config -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.engine.arg_utils import EngineArgs -from vllm.model_executor.layers.linear import QKVParallelLinear -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \ - CompressedTensorsLinearMethod -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import \ - CompressedTensorsW8A8Int8 - -from tpu_commons.models.vllm.jax_qkv_parallel_linear import \ - JaxQKVParallelLinear - -P = PartitionSpec - -_vllm_config = EngineArgs( - model="Qwen/Qwen2-1.5B-Instruct", - max_model_len=64, - max_num_batched_tokens=64, - max_num_seqs=4, -).create_engine_config() - - -@pytest.fixture(autouse=True) -def setup_environment(): - # This is a fake config used for init dist env. - # QKVParallelLinear needs dist env to be initialized. 
- with set_current_vllm_config(_vllm_config): - temp_file = tempfile.mkstemp()[1] - init_distributed_environment( - 1, - 0, - local_rank=0, - distributed_init_method=f"file://{temp_file}", - backend="gloo") - ensure_model_parallel_initialized(1, 1) - - -@pytest.mark.parametrize("bias", [False, True]) -@pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh()]) -@pytest.mark.parametrize("fuse_matmuls", [False, True]) -@pytest.mark.parametrize("enable_sp", [False, True]) -def test_jax_qkv_parallel_linear(bias, mesh, fuse_matmuls, enable_sp): - dtype = torch.bfloat16 - with set_current_vllm_config(_vllm_config): - qkv_linear = QKVParallelLinear( - hidden_size=4096, - head_size=128, - total_num_heads=32, - total_num_kv_heads=8, - bias=bias, - params_dtype=dtype, - return_bias=False, - ) - - qkv_linear.weight.data = torch.rand_like(qkv_linear.weight.data) / 10 - if bias: - qkv_linear.bias.data = torch.rand_like(qkv_linear.bias.data) - qkv_linear = qkv_linear.to('cpu') - qkv_linear.quant_method.process_weights_after_loading(qkv_linear) - - input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 - input_tensor = input_tensor.to('cpu') - output = qkv_linear(input_tensor).to(dtype) - - # Set jax default device to workaround a layout bug in JAX 0.7.0 and earlier - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): - jax_qkv_linear = JaxQKVParallelLinear(qkv_linear, mesh, fuse_matmuls, - enable_sp) - jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False)) - jax_input_tensor.apply_jax_(jax.device_put, - NamedSharding(mesh, P(None, None))) - with torchax.default_env(): - jax_output = jax_qkv_linear(jax_input_tensor) - # j2t() doens't support bfloat16, so we cast it into float32 as an intermedate step. - jax_output = j2t(jax_output.to(torch.float32)).to(dtype) - - torch.testing.assert_close(output, jax_output) - - -@pytest.mark.parametrize("bias", [False, True]) -@pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh()]) -@pytest.mark.parametrize("fuse_matmuls", [False, True]) -@pytest.mark.parametrize("enable_sp", [False, True]) -def test_jax_qkv_parallel_linear_w8a8_int8(bias, mesh, fuse_matmuls, - enable_sp): - dtype = torch.bfloat16 - with set_current_vllm_config(_vllm_config): - qkv_linear = QKVParallelLinear( - hidden_size=4096, - head_size=128, - total_num_heads=32, - total_num_kv_heads=8, - bias=bias, - params_dtype=dtype, - return_bias=False, - quant_config=test_utils.gen_vllm_w8a8_int8_config(), - ) - - # Assert we're testing the right code path when quant config is set. - assert isinstance(qkv_linear.quant_method, CompressedTensorsLinearMethod) - assert isinstance(qkv_linear.scheme, CompressedTensorsW8A8Int8) - - qkv_linear.weight.data = torch.randint_like(qkv_linear.weight.data, - low=-128, - high=128) - qkv_linear.weight_scale.data = torch.rand_like( - qkv_linear.weight_scale.data) / 10 - if bias: - qkv_linear.bias.data = torch.rand_like(qkv_linear.bias.data) - qkv_linear = qkv_linear.to('cpu') - qkv_linear.quant_method.process_weights_after_loading(qkv_linear) - - input_tensor = torch.rand(16, 4096, dtype=dtype) / 10 - input_tensor = input_tensor.to('cpu') - # Overwrite the torch_xla kernel with a reference implementation, as it's difficult to call torch_xla in tpu_commons and we want to run the ref result on CPU. 
- with patch( - "vllm.model_executor.layers.quantization.kernels.scaled_mm.xla.XLAScaledMMLinearKernel.apply_weights", - new=test_utils.quantized_matmul_ref): - output = qkv_linear(input_tensor).to(dtype) - - # Set jax default device to workaround a layout bug in JAX 0.7.0 and earlier - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): - jax_qkv_linear = JaxQKVParallelLinear(qkv_linear, mesh, fuse_matmuls, - enable_sp) - jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False)) - jax_input_tensor.apply_jax_(jax.device_put, - NamedSharding(mesh, P(None, None))) - with torchax.default_env(): - jax_output = jax_qkv_linear(jax_input_tensor) - # j2t() doens't support bfloat16, so we cast it into float32 as an intermedate step. - jax_output = j2t(jax_output.to(torch.float32)).to(dtype) - - torch.testing.assert_close(output, jax_output, atol=5, rtol=0.1) diff --git a/tests/models/vllm/test_jax_row_parallel_linear.py b/tests/models/vllm/test_jax_row_parallel_linear.py deleted file mode 100644 index 688cd65d5..000000000 --- a/tests/models/vllm/test_jax_row_parallel_linear.py +++ /dev/null @@ -1,156 +0,0 @@ -import tempfile -from unittest.mock import patch - -import jax -import pytest -import torch -import torchax -import utils as test_utils -from jax.sharding import NamedSharding, PartitionSpec -from torchax.interop import torch_view -from torchax.ops.mappings import j2t, t2j -from vllm.config import set_current_vllm_config -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.engine.arg_utils import EngineArgs -from vllm.model_executor.layers.linear import RowParallelLinear -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \ - CompressedTensorsLinearMethod -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import \ - CompressedTensorsW8A8Int8 - -from tpu_commons.models.vllm.jax_row_parallel_linear import \ - JaxRowParallelLinear - -P = PartitionSpec - - -@pytest.fixture(autouse=True) -def setup_environment(): - # This is a fake config used for init dist env. - # RowParallelLinear needs dist env to be initialized. 
- engine_args = EngineArgs( - model="Qwen/Qwen2-1.5B-Instruct", - max_model_len=64, - max_num_batched_tokens=64, - max_num_seqs=4, - ) - - vllm_config = engine_args.create_engine_config() - - with set_current_vllm_config(vllm_config): - temp_file = tempfile.mkstemp()[1] - init_distributed_environment( - 1, - 0, - local_rank=0, - distributed_init_method=f"file://{temp_file}", - backend="gloo") - ensure_model_parallel_initialized(1, 1) - - -@pytest.mark.parametrize("bias", [False, True]) -@pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh()]) -@pytest.mark.parametrize("enable_sp", [False, True]) -def test_jax_row_parallel_linear(bias, mesh, enable_sp): - dtype = torch.bfloat16 - - engine_args = EngineArgs( - model="Qwen/Qwen2-1.5B-Instruct", - max_model_len=64, - max_num_batched_tokens=64, - max_num_seqs=4, - ) - vllm_config = engine_args.create_engine_config() - with set_current_vllm_config(vllm_config): - row_linear = RowParallelLinear( - input_size=4096, - output_size=8192, - bias=bias, - params_dtype=dtype, - return_bias=False, - ) - - row_linear.weight.data = torch.rand_like(row_linear.weight.data) / 10 - if bias: - row_linear.bias.data = torch.rand_like(row_linear.bias.data) - row_linear = row_linear.to('cpu') - row_linear.quant_method.process_weights_after_loading(row_linear) - - input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 - input_tensor = input_tensor.to('cpu') - output = row_linear(input_tensor).to(dtype) - - # Set jax default device to workaround a layout bug in JAX 0.7.0 and earlier - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): - jax_row_linear = JaxRowParallelLinear( - row_linear, mesh=mesh, enable_sequence_parallelism=enable_sp) - jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False)) - jax_input_tensor.apply_jax_(jax.device_put, - NamedSharding(mesh, P(None, None))) - with torchax.default_env(): - jax_output = jax_row_linear(jax_input_tensor) - # j2t() doens't support bfloat16, so we cast it into float32 as an intermedate step. - jax_output = j2t(jax_output.to(torch.float32)).to(dtype) - - torch.testing.assert_close(output, jax_output) - - -@pytest.mark.parametrize("bias", [False, True]) -@pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh()]) -@pytest.mark.parametrize("enable_sp", [False, True]) -def test_jax_row_parallel_linear_w8a8_int8(bias, mesh, enable_sp): - dtype = torch.bfloat16 - - engine_args = EngineArgs( - model="Qwen/Qwen2-1.5B-Instruct", - max_model_len=64, - max_num_batched_tokens=64, - max_num_seqs=4, - ) - vllm_config = engine_args.create_engine_config() - with set_current_vllm_config(vllm_config): - row_linear = RowParallelLinear( - input_size=4096, - output_size=8192, - bias=bias, - params_dtype=dtype, - return_bias=False, - quant_config=test_utils.gen_vllm_w8a8_int8_config(), - ) - - # Assert we're testing the right code path when quant config is set. 
- assert isinstance(row_linear.quant_method, CompressedTensorsLinearMethod) - assert isinstance(row_linear.scheme, CompressedTensorsW8A8Int8) - - row_linear.weight.data = torch.randint_like(row_linear.weight.data, - low=-128, - high=128) - row_linear.weight_scale.data = torch.rand_like( - row_linear.weight_scale.data) / 10 - if bias: - row_linear.bias.data = torch.rand_like(row_linear.bias.data) - row_linear = row_linear.to('cpu') - row_linear.quant_method.process_weights_after_loading(row_linear) - - input_tensor = torch.rand(16, 4096, dtype=dtype) / 10 - input_tensor = input_tensor.to('cpu') - # Overwrite the torch_xla kernel with a reference implementation, as it's difficult to call torch_xla in tpu_commons and we want to run the ref result on CPU. - with patch( - "vllm.model_executor.layers.quantization.kernels.scaled_mm.xla.XLAScaledMMLinearKernel.apply_weights", - new=test_utils.quantized_matmul_ref): - output = row_linear(input_tensor).to(dtype) - - # Set jax default device to workaround a layout bug in JAX 0.7.0 and earlier - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): - jax_row_linear = JaxRowParallelLinear( - row_linear, mesh=mesh, enable_sequence_parallelism=enable_sp) - jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False)) - jax_input_tensor.apply_jax_(jax.device_put, - NamedSharding(mesh, P(None, None))) - with torchax.default_env(): - jax_output = jax_row_linear(jax_input_tensor) - # j2t() doens't support bfloat16, so we cast it into float32 as an intermedate step. - jax_output = j2t(jax_output.to(torch.float32)).to(dtype) - - torch.testing.assert_close(output, jax_output, atol=5, rtol=0.1) diff --git a/tests/models/vllm/test_pallas_torchax.py b/tests/models/vllm/test_pallas_torchax.py deleted file mode 100644 index 0cc80326a..000000000 --- a/tests/models/vllm/test_pallas_torchax.py +++ /dev/null @@ -1,508 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -import torch -from vllm.attention.backends.abstract import AttentionType -from vllm.config import ModelConfig, SchedulerConfig, VllmConfig - -from tpu_commons.attention.backends.pallas_torchax import ( - PallasAttentionBackend, PallasAttentionBackendImpl, PallasMetadata, - write_to_kv_cache) - - -class TestPallasMetadata: - - def test_init(self): - slot_mapping = torch.tensor([1, 2, 3]) - block_tables = torch.tensor([[1, 2], [3, 4]]) - context_lens = torch.tensor([10, 20]) - query_start_loc = torch.tensor([0, 10]) - num_seqs = torch.tensor([2]) - num_slices = torch.tensor([1]) - - metadata = PallasMetadata(slot_mapping=slot_mapping, - block_tables=block_tables, - context_lens=context_lens, - query_start_loc=query_start_loc, - num_seqs=num_seqs, - num_slices=num_slices) - - assert torch.equal(metadata.slot_mapping, slot_mapping) - assert torch.equal(metadata.block_tables, block_tables) - assert torch.equal(metadata.context_lens, context_lens) - assert torch.equal(metadata.query_start_loc, query_start_loc) - assert torch.equal(metadata.num_seqs, num_seqs) - assert torch.equal(metadata.num_slices, num_slices) - - def test_tree_flatten_unflatten(self): - slot_mapping = torch.tensor([1, 2, 3]) - block_tables = torch.tensor([[1, 2], [3, 4]]) - context_lens = torch.tensor([10, 20]) - query_start_loc = torch.tensor([0, 10]) - num_seqs = torch.tensor([2]) - num_slices = torch.tensor([1]) - - metadata = PallasMetadata(slot_mapping=slot_mapping, - block_tables=block_tables, - context_lens=context_lens, - query_start_loc=query_start_loc, - num_seqs=num_seqs, - 
num_slices=num_slices) - - children, aux_data = metadata.tree_flatten() - reconstructed = PallasMetadata.tree_unflatten(aux_data, children) - - assert torch.equal(reconstructed.slot_mapping, metadata.slot_mapping) - assert torch.equal(reconstructed.block_tables, metadata.block_tables) - assert torch.equal(reconstructed.context_lens, metadata.context_lens) - assert torch.equal(reconstructed.query_start_loc, - metadata.query_start_loc) - assert torch.equal(reconstructed.num_seqs, metadata.num_seqs) - assert torch.equal(reconstructed.num_slices, metadata.num_slices) - assert aux_data is None - - -class TestPallasAttentionBackend: - - def test_get_state_cls(self): - from vllm.attention.backends.utils import CommonAttentionState - assert PallasAttentionBackend.get_state_cls() == CommonAttentionState - - def test_get_name(self): - assert PallasAttentionBackend.get_name() == "PALLAS_VLLM_V1" - - def test_get_impl_cls(self): - assert PallasAttentionBackend.get_impl_cls( - ) == PallasAttentionBackendImpl - - def test_get_metadata_cls(self): - assert PallasAttentionBackend.get_metadata_cls() == PallasMetadata - - def test_get_kv_cache_shape(self): - num_blocks = 10 - block_size = 16 - num_kv_heads = 8 - head_size = 256 - - shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, block_size, num_kv_heads, head_size) - - expected_shape = (num_blocks, block_size, num_kv_heads * 2, head_size) - assert shape == expected_shape - - def test_get_kv_cache_shape_unaligned_head_size(self): - num_blocks = 10 - block_size = 16 - num_kv_heads = 8 - head_size = 96 # Not aligned to 128 - - shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, block_size, num_kv_heads, head_size) - - # 96 should be padded to 128 - padded_head_size = 128 - expected_shape = (num_blocks, block_size, num_kv_heads * 2, - padded_head_size) - assert shape == expected_shape - - def test_swap_blocks_raises_error(self): - src_kv_cache = torch.empty(0) - dst_kv_cache = torch.empty(0) - src_to_dst = torch.empty(0) - - with pytest.raises( - RuntimeError, - match="swap_blocks is not used for the TPU backend"): - PallasAttentionBackend.swap_blocks(src_kv_cache, dst_kv_cache, - src_to_dst) - - def test_get_min_page_size(self): - model_config = MagicMock(spec=ModelConfig) - model_config.max_model_len = 2048 - - scheduler_config = MagicMock(spec=SchedulerConfig) - scheduler_config.max_num_seqs = 256 - - vllm_config = MagicMock(spec=VllmConfig) - vllm_config.model_config = model_config - vllm_config.scheduler_config = scheduler_config - - min_page_size = PallasAttentionBackend.get_min_page_size(vllm_config) - assert min_page_size > 0 - # Should be a power of 2 - assert (min_page_size & (min_page_size - 1)) == 0 - - def test_get_page_size(self): - model_config = MagicMock(spec=ModelConfig) - model_config.max_model_len = 2048 - - vllm_config = MagicMock(spec=VllmConfig) - vllm_config.model_config = model_config - - page_size = PallasAttentionBackend.get_page_size(vllm_config) - assert 16 <= page_size <= 256 - # Should be a power of 2 - assert (page_size & (page_size - 1)) == 0 - - def test_get_page_size_small_model_len(self): - model_config = MagicMock(spec=ModelConfig) - model_config.max_model_len = 64 # Small model length - - vllm_config = MagicMock(spec=VllmConfig) - vllm_config.model_config = model_config - - page_size = PallasAttentionBackend.get_page_size(vllm_config) - assert page_size == 16 - - def test_get_page_size_large_model_len(self): - model_config = MagicMock(spec=ModelConfig) - model_config.max_model_len = 8192 # Large 
model length - - vllm_config = MagicMock(spec=VllmConfig) - vllm_config.model_config = model_config - - page_size = PallasAttentionBackend.get_page_size(vllm_config) - assert page_size == 256 - - -class TestPallasAttentionBackendImpl: - - def test_init_valid_params(self): - impl = PallasAttentionBackendImpl( - num_heads=32, - head_size=128, - scale=0.088, - num_kv_heads=8, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype="auto", - attn_type=AttentionType.DECODER, - ) - - assert impl.num_heads == 32 - assert impl.head_size == 128 - assert impl.scale == 0.088 - assert impl.num_kv_heads == 8 - assert impl.num_queries_per_kv == 4 - assert impl.sliding_window is None - - def test_init_with_alibi_slopes_raises_error(self): - with pytest.raises(NotImplementedError, - match="Alibi slopes is not supported"): - PallasAttentionBackendImpl( - num_heads=32, - head_size=128, - scale=0.088, - num_kv_heads=8, - alibi_slopes=[1.0, 2.0], - sliding_window=None, - kv_cache_dtype="auto", - attn_type=AttentionType.DECODER, - ) - - def test_init_with_fp8_kv_cache_raises_error(self): - with pytest.raises(NotImplementedError, - match="FP8 KV cache dtype is not supported"): - PallasAttentionBackendImpl( - num_heads=32, - head_size=128, - scale=0.088, - num_kv_heads=8, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype="fp8", - attn_type=AttentionType.DECODER, - ) - - def test_init_with_encoder_attention_raises_error(self): - with pytest.raises(NotImplementedError, - match="Encoder self-attention"): - PallasAttentionBackendImpl( - num_heads=32, - head_size=128, - scale=0.088, - num_kv_heads=8, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype="auto", - attn_type=AttentionType.ENCODER, - ) - - @patch( - 'tpu_commons.attention.backends.pallas_torchax.ragged_paged_attention') - @patch('tpu_commons.attention.backends.pallas_torchax.get_forward_context') - def test_forward_empty_kv_cache(self, mock_get_context, mock_attention): - impl = PallasAttentionBackendImpl( - num_heads=32, - head_size=128, - scale=0.088, - num_kv_heads=8, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype="auto", - attn_type=AttentionType.DECODER, - ) - - layer = MagicMock() - layer._k_scale_float = 1.0 - layer._v_scale_float = 1.0 - - query = torch.randn(2, 4096) # 2 tokens, 32 heads * 128 head_size - key = torch.randn(2, 1024) # 2 tokens, 8 kv_heads * 128 head_size - value = torch.randn(2, 1024) - kv_cache = torch.empty(0) # Empty cache - - metadata = PallasMetadata(slot_mapping=torch.tensor([0, 1]), - block_tables=torch.tensor([[0, 1]]), - context_lens=torch.tensor([2]), - query_start_loc=torch.tensor([0, 2]), - num_seqs=torch.tensor([1]), - num_slices=torch.tensor([1])) - - result = impl.forward(layer, query, key, value, kv_cache, metadata) - assert result.shape == query.shape - mock_attention.assert_not_called() - - def test_forward_with_output_scale_raises_error(self): - impl = PallasAttentionBackendImpl( - num_heads=32, - head_size=128, - scale=0.088, - num_kv_heads=8, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype="auto", - attn_type=AttentionType.DECODER, - ) - - layer = MagicMock() - query = torch.randn(2, 512) - key = torch.randn(2, 256) - value = torch.randn(2, 256) - kv_cache = torch.randn(10, 16, 16, 128) - metadata = MagicMock() - output_scale = torch.tensor([1.0]) - - with pytest.raises(NotImplementedError, - match="fused output quantization"): - impl.forward(layer, - query, - key, - value, - kv_cache, - metadata, - output_scale=output_scale) - - def 
test_init_with_irope_warning(self): - with patch('tpu_commons.attention.backends.pallas_torchax.logger' - ) as mock_logger: - _ = PallasAttentionBackendImpl( - num_heads=32, - head_size=128, - scale=0.088, - num_kv_heads=8, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype="auto", - attn_type=AttentionType.DECODER, - use_irope=True, - ) - mock_logger.warning_once.assert_called_once_with( - "Using irope in Pallas is not supported yet, it will fall back " - "to global attention for long context.") - - @patch( - 'tpu_commons.attention.backends.pallas_torchax.ragged_paged_attention') - @patch('tpu_commons.attention.backends.pallas_torchax.get_forward_context') - @patch('tpu_commons.attention.backends.pallas_torchax.write_to_kv_cache') - def test_forward_full_flow(self, mock_write_kv, mock_get_context, - mock_attention): - impl = PallasAttentionBackendImpl( - num_heads=32, - head_size=128, - scale=0.088, - num_kv_heads=8, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype="auto", - attn_type=AttentionType.DECODER, - ) - - layer = MagicMock() - layer._k_scale_float = 1.0 - layer._v_scale_float = 1.0 - layer.kv_cache = {} - - mock_context = MagicMock() - mock_context.virtual_engine = 0 - mock_get_context.return_value = mock_context - - query = torch.randn(2, 4096) # 2 tokens, 32 heads * 128 head_size - key = torch.randn(2, 1024) # 2 tokens, 8 kv_heads * 128 head_size - value = torch.randn(2, 1024) - kv_cache = torch.randn(10, 16, 16, 128) # Non-empty cache - - metadata = PallasMetadata(slot_mapping=torch.tensor([0, 1]), - block_tables=torch.tensor([[0, 1]]), - context_lens=torch.tensor([2]), - query_start_loc=torch.tensor([0, 2]), - num_seqs=torch.tensor([1]), - num_slices=torch.tensor([1])) - - # Mock write_to_kv_cache to return the same cache - mock_write_kv.return_value = kv_cache - - # Mock attention output - mock_attention.return_value = torch.randn(2, 32, 128) - - result = impl.forward(layer, query, key, value, kv_cache, metadata) - - # Verify write_to_kv_cache was called - mock_write_kv.assert_called_once() - - # Verify ragged_paged_attention was called - mock_attention.assert_called_once() - - # Check result shape - assert result.shape == (2, 4096) - - @patch( - 'tpu_commons.attention.backends.pallas_torchax.ragged_paged_attention') - @patch('tpu_commons.attention.backends.pallas_torchax.get_forward_context') - def test_forward_with_kv_sharing(self, mock_get_context, mock_attention): - impl = PallasAttentionBackendImpl( - num_heads=32, - head_size=128, - scale=0.088, - num_kv_heads=8, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype="auto", - attn_type=AttentionType.DECODER, - kv_sharing_target_layer_name="earlier_layer", - ) - - layer = MagicMock() - layer._k_scale_float = 1.0 - layer._v_scale_float = 1.0 - - query = torch.randn(2, 4096) # 2 tokens, 32 heads * 128 head_size - key = torch.randn(2, 1024) # 2 tokens, 8 kv_heads * 128 head_size - value = torch.randn(2, 1024) - kv_cache = torch.randn(10, 16, 16, 128) - - metadata = PallasMetadata(slot_mapping=torch.tensor([0, 1]), - block_tables=torch.tensor([[0, 1]]), - context_lens=torch.tensor([2]), - query_start_loc=torch.tensor([0, 2]), - num_seqs=torch.tensor([1]), - num_slices=torch.tensor([1])) - - mock_attention.return_value = torch.randn(2, 32, 128) - - result = impl.forward(layer, query, key, value, kv_cache, metadata) - - # Verify get_forward_context was not called (KV cache sharing skips write) - mock_get_context.assert_not_called() - assert result.shape == (2, 4096) - - @patch( - 
'tpu_commons.attention.backends.pallas_torchax.ragged_paged_attention') - @patch('tpu_commons.attention.backends.pallas_torchax.get_forward_context') - @patch('tpu_commons.attention.backends.pallas_torchax.write_to_kv_cache') - def test_forward_with_head_padding(self, mock_write_kv, mock_get_context, - mock_attention): - impl = PallasAttentionBackendImpl( - num_heads=8, - head_size=96, # Not aligned to 128, will need padding - scale=0.25, - num_kv_heads=8, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype="auto", - attn_type=AttentionType.DECODER, - ) - - layer = MagicMock() - layer._k_scale_float = 1.0 - layer._v_scale_float = 1.0 - layer.kv_cache = {} - - mock_context = MagicMock() - mock_context.virtual_engine = 0 - mock_get_context.return_value = mock_context - - query = torch.randn(2, 768) # 2 tokens, 8 heads * 96 head_size - key = torch.randn(2, 768) # 2 tokens, 8 kv_heads * 96 head_size - value = torch.randn(2, 768) - kv_cache = torch.randn(10, 16, 16, 128) # Padded head size - - metadata = PallasMetadata(slot_mapping=torch.tensor([0, 1]), - block_tables=torch.tensor([[0, 1]]), - context_lens=torch.tensor([2]), - query_start_loc=torch.tensor([0, 2]), - num_seqs=torch.tensor([1]), - num_slices=torch.tensor([1])) - - mock_write_kv.return_value = kv_cache - # Return padded output that will be trimmed - mock_attention.return_value = torch.randn(2, 8, 128) - - result = impl.forward(layer, query, key, value, kv_cache, metadata) - - # Result should be trimmed back to original head size - assert result.shape == (2, 768) # 8 heads * 96 original head_size - - -@patch('tpu_commons.attention.backends.pallas_torchax.call_jax') -@patch('tpu_commons.attention.backends.pallas_torchax.kv_cache_update') -def test_write_to_kv_cache(mock_kv_cache_update, mock_call_jax): - # Mock the JAX function call to return the same kv_cache - mock_call_jax.return_value = torch.randn(160, 16, 128) # reshaped size - - key = torch.randn(2, 8, 128) # 2 tokens, 8 kv_heads, 128 head_size - value = torch.randn(2, 8, 128) # 2 tokens, 8 kv_heads, 128 head_size - kv_cache = torch.randn( - 10, 16, 16, - 128) # 10 blocks, 16 block_size, 16 kv_heads*2, 128 head_size - slot_mapping = torch.tensor([[0, 1, 0], [16, 17, 0], [0, 0, 0]]) - num_slices = torch.tensor([1]) - - result = write_to_kv_cache(key, value, kv_cache, slot_mapping, num_slices) - - # Verify the JAX function was called - mock_call_jax.assert_called_once() - - # Check the result shape matches input kv_cache - assert result.shape == kv_cache.shape - - # Verify kv_cache_update was passed to call_jax - args, kwargs = mock_call_jax.call_args - assert args[0] == mock_kv_cache_update - assert kwargs['page_size'] == 16 - - -def test_write_to_kv_cache_tensor_shapes(): - # Create tensors with correct shapes: key/value should be flattened - key = torch.randn(3, 1024) # 3 tokens, 8 kv_heads * 128 head_size - value = torch.randn(3, 1024) # 3 tokens, 8 kv_heads * 128 head_size - kv_cache = torch.randn( - 5, 8, 16, 128) # 5 blocks, 8 block_size, 16 kv_heads*2, 128 head_size - slot_mapping = torch.tensor([[0, 1, 2], [8, 9, 10], [0, 0, 0]]) - num_slices = torch.tensor([1]) - - with patch('tpu_commons.attention.backends.pallas_torchax.call_jax' - ) as mock_call_jax: - # Mock return value with correct shape - mock_call_jax.return_value = torch.randn(40, 16, 128) # reshaped size - - result = write_to_kv_cache(key, value, kv_cache, slot_mapping, - num_slices) - - # Check input tensor shapes passed to JAX - args, kwargs = mock_call_jax.call_args - kv_input = args[1] # The 
concatenated kv tensor - - # kv should be [3 tokens, 16 combined_heads, 128 padded_head_size] - assert kv_input.shape == (3, 16, 128) - assert result.shape == kv_cache.shape diff --git a/tests/models/vllm/utils.py b/tests/models/vllm/utils.py deleted file mode 100644 index 128694401..000000000 --- a/tests/models/vllm/utils.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Optional - -import jax -import torch -import torch.nn.functional as F -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \ - CompressedTensorsConfig - - -def get_spmd_mesh(num_devices: int = 1): - axis_names = ("data", "model") - devices = sorted(jax.devices(), key=lambda d: d.id)[0:num_devices] - mesh_shape = (1, len(devices)) - return jax.make_mesh(mesh_shape, axis_names, devices=devices) - - -def quantized_matmul_ref(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - w_q = getattr(layer, "weight") - w_s = getattr(layer, "weight_scale") - output = F.linear(x, w_q.type(torch.bfloat16)) - output = output * w_s - - if bias is not None: - output = output + bias - return output - - -def gen_vllm_w8a8_int8_config(): - return CompressedTensorsConfig.from_config({ - "format": "int-quantized", - "config_groups": { - "group_0": { - "input_activations": { - "dynamic": True, - "num_bits": 8, - "strategy": "token", - "symmetric": True, - "type": "int" - }, - "targets": ["Linear"], - "weights": { - "dynamic": False, - "num_bits": 8, - "strategy": "channel", - "symmetric": True, - "type": "int" - } - } - } - }) diff --git a/tpu_commons/attention/backends/pallas_torchax.py b/tpu_commons/attention/backends/pallas_torchax.py index b47453c8e..dbffa3d9c 100644 --- a/tpu_commons/attention/backends/pallas_torchax.py +++ b/tpu_commons/attention/backends/pallas_torchax.py @@ -1,23 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass -from typing import Optional +import functools +from typing import Optional, Tuple +import jax import torch -from jax.tree_util import register_pytree_node_class -from torchax.interop import call_jax +from jax.sharding import Mesh +from torchax.interop import jax_view, torch_view from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) -from vllm.attention.backends.utils import CommonAttentionState -from vllm.config import VllmConfig -from vllm.forward_context import ForwardContext, get_forward_context -from vllm.utils import cdiv, next_power_of_2 +from vllm.model_executor.models.utils import extract_layer_index from tpu_commons.logger import init_logger +from tpu_commons.models.jax.attention import attention +from tpu_commons.models.jax.attention_metadata import AttentionMetadata # Register custom op dispatcher. 
-from tpu_commons.models.torchax.torchax_wrapper import (kv_cache_update, - ragged_paged_attention) -from tpu_commons.utils import TPU_HEAD_SIZE_ALIGNMENT +from tpu_commons.models.vllm.vllm_model_wrapper_context import \ + get_vllm_model_wrapper_context logger = init_logger(__name__) @@ -32,90 +31,6 @@ def get_name() -> str: def get_impl_cls() -> type["PallasAttentionBackendImpl"]: return PallasAttentionBackendImpl - @staticmethod - def get_metadata_cls() -> type["PallasMetadata"]: - return PallasMetadata - - @staticmethod - def get_state_cls() -> type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> tuple[int, ...]: - padded_head_size = cdiv( - head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT - return (num_blocks, block_size, num_kv_heads * 2, padded_head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - raise RuntimeError("swap_blocks is not used for the TPU backend.") - - # In recent TPU generations, up to v6e, the SMEM size is 1MB. The - # block_tables within the PallasMetadata constitute almost the entire SMEM - # requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here - # we simply make sure that the size is smaller than half of SMEM capacity. - @staticmethod - def get_min_page_size(vllm_config: VllmConfig) -> int: - max_num_page_per_req = (1024 * 1024 // 2 // - vllm_config.scheduler_config.max_num_seqs // 4) - min_page_size = cdiv(vllm_config.model_config.max_model_len, - max_num_page_per_req) - min_page_size = 1 << (min_page_size - 1).bit_length() - return min_page_size - - # TPU has limited SREGs (scalar registers), if page_size is too small, we - # can spill SREGs easily which leads to bad performance. The strategy we - # apply here is trying to split max-model-len to 16 pages which make the - # spill less likely. Meanwhile we make sure the page size is in [16, 256]. - @staticmethod - def get_page_size(vllm_config: VllmConfig) -> int: - page_size = next_power_of_2( - vllm_config.model_config.max_model_len) // 16 - if page_size <= 16: - return 16 - if page_size >= 256: - return 256 - return page_size - - -@dataclass -@register_pytree_node_class -class PallasMetadata: - # NOTE(sang): Definition of context_len, query_len, and seq_len. 
- # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # Used in the PallasAttentionBackendImpl - slot_mapping: torch.Tensor - block_tables: torch.Tensor - context_lens: torch.Tensor - query_start_loc: torch.Tensor - num_seqs: torch.Tensor - num_slices: torch.Tensor - - def tree_flatten(self): - children = (self.slot_mapping, self.block_tables, self.context_lens, - self.query_start_loc, self.num_seqs, self.num_slices) - aux_data = None - return (children, aux_data) - - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls(*children) - class PallasAttentionBackendImpl(AttentionImpl): @@ -164,114 +79,84 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: PallasMetadata, + attn_metadata: AttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Forward pass with Pallas attention. - - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] - """ if output_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for PallasAttentionBackendImpl") - - # For determine_available_memory case. - if kv_cache.numel() == 0: - if output is None: - output = torch.ones_like(query) - return output - - assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 - num_tokens, hidden_size = query.shape - query = query.view(num_tokens, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: - padded_head_size = cdiv( - self.head_size, - TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT - query = torch.nn.functional.pad( - query, (0, padded_head_size - self.head_size), value=0.0) - key = torch.nn.functional.pad( - key, (0, padded_head_size - self.head_size), value=0.0) - value = torch.nn.functional.pad( - value, (0, padded_head_size - self.head_size), value=0.0) - - if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0: - # Write input keys and values to the KV cache. - # Skip this if sharing KV cache with an earlier attention layer. - slot_mapping = attn_metadata.slot_mapping - kv_cache = write_to_kv_cache(key, value, kv_cache, slot_mapping, - attn_metadata.num_slices) - forward_context: ForwardContext = get_forward_context() - layer.kv_cache[forward_context.virtual_engine] = kv_cache - - ragged_paged_attention_op = ragged_paged_attention - - output = ragged_paged_attention_op( - query, - kv_cache, - attn_metadata.context_lens, - attn_metadata.block_tables, - attn_metadata.query_start_loc, - attn_metadata.num_seqs, - # By default, the system utilizes optimized block size and - # vmem_limit_bytes parameters from the kernel repository. However, - # these can be manually adjusted for debugging if necessary. 
- num_kv_pages_per_block=None, - num_queries_per_block=None, - vmem_limit_bytes=100 * 1024 * 1024, - use_kernel=True, - sm_scale=self.scale, - sliding_window=self.sliding_window, - soft_cap=self.logits_soft_cap, - ) - - if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: - output = output[:, :, :self.head_size] - - return output.reshape(num_tokens, hidden_size) - - -def write_to_kv_cache(key: torch.Tensor, value: torch.Tensor, - kv_cache: torch.Tensor, slot_mapping: torch.Tensor, - num_slices: torch.Tensor) -> torch.Tensor: - """ Write the key and values to the KV cache. - - Args: - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] - slot_mapping = [3, padded_num_slices] - num_slices = [1] - - """ - num_blocks, block_size, num_combined_kv_heads, head_size = kv_cache.shape - head_size = cdiv(head_size, - TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT - num_kv_heads = num_combined_kv_heads // 2 - - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, - head_size) - - kv_cache = kv_cache.reshape(-1, num_combined_kv_heads, head_size) - kv_cache = call_jax(kv_cache_update, - kv, - slot_mapping, - kv_cache, - num_slices, - page_size=block_size) - kv_cache = kv_cache.reshape(num_blocks, block_size, num_combined_kv_heads, - head_size) - return kv_cache + "fused output quantization is not yet supported for " + "PallasAttentionBackendImpl") + + if kv_cache.numel(): + raise RuntimeError( + "KV cache from vLLM Attention layer should be empty but has " + "the size of %s.", kv_cache.numel()) + + del kv_cache # Use kv_cache from vllm wrapper context values instead. + + vllm_model_wrapper_context = get_vllm_model_wrapper_context() + layer_idx = extract_layer_index(layer.layer_name) + kv_cache = vllm_model_wrapper_context.kv_caches[layer_idx] + mesh = vllm_model_wrapper_context.mesh + + new_kv_cache, outputs = _jax_attn_func(kv_cache, jax_view(query), + jax_view(key), jax_view(value), + attn_metadata, mesh, self.scale, + self.head_size, self.num_heads, + self.num_kv_heads) + vllm_model_wrapper_context.kv_caches[layer_idx] = new_kv_cache + + return torch_view(outputs) + + +@functools.partial( + jax.jit, + static_argnums=(5, 6, 7, 8, + 9), # mesh, scale, head_size, num_heads, num_kv_heads + donate_argnums=(0, ), # donate kv_cache +) +def _jax_attn_func( + kv_cache: jax.Array, + q: jax.Array, + k: jax.Array, + v: jax.Array, + attention_metadata: AttentionMetadata, + mesh: Mesh, + scale: float, + head_size: int, + num_heads: int, + num_kv_heads: int, +) -> Tuple[jax.Array, jax.Array]: + del scale # Unused for now, as the attention function applies a default scale. 
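+    # NOTE: kv_cache is donated to this jitted function (donate_argnums=(0, )),
+    # so callers must keep using the returned new_kv_cache buffer rather than
+    # the donated input array.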
+
+    # Get shapes from vllm
+    q_len, q_compute_dim = q.shape
+    k_len, k_compute_dim = k.shape
+    assert k.shape == v.shape
+    assert q_compute_dim == head_size * num_heads
+    assert k_compute_dim == head_size * num_kv_heads
+
+    # Convert the shapes from vLLM's convention to what the attention function expects
+    # bs, num_heads, q_len, head_size
+    q = q.reshape(q_len, num_heads, head_size)
+    # bs, num_kv_heads, k_len, head_size
+    k = k.reshape(k_len, num_kv_heads, head_size)
+    v = v.reshape(k_len, num_kv_heads, head_size)
+
+    new_kv_cache, outputs = attention(
+        kv_cache,
+        q,
+        k,
+        v,
+        attention_metadata,
+        mesh,
+    )
+
+    # Convert the shape back to vLLM's convention
+    assert outputs.shape[0] == q_len
+    assert outputs.shape[1] == num_heads
+    assert outputs.shape[2] == head_size
+    outputs = outputs.reshape(q_len, q_compute_dim)
+
+    return new_kv_cache, outputs
diff --git a/tpu_commons/models/vllm/jax_attention.py b/tpu_commons/models/vllm/jax_attention.py
deleted file mode 100644
index a2b15dbfa..000000000
--- a/tpu_commons/models/vllm/jax_attention.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import functools
-from typing import Optional, Tuple
-
-import jax
-import torch
-import torch.nn
-from jax.sharding import Mesh
-from torchax.interop import jax_view, torch_view
-from vllm.attention import Attention as VllmAttention
-from vllm.model_executor.models.utils import extract_layer_index
-
-from tpu_commons.models.jax.attention import attention
-from tpu_commons.models.jax.attention_metadata import AttentionMetadata
-from tpu_commons.models.vllm.vllm_model_wrapper_context import \
-    get_vllm_model_wrapper_context
-
-
-@functools.partial(
-    jax.jit,
-    static_argnums=(5, 6, 7, 8,
-                    9),  # mesh, scale, head_size, num_heads, num_kv_heads
-    donate_argnums=(0, ),  # donate kv_cache
-)
-def _jax_attn_func(
-    kv_cache: jax.Array,
-    q: jax.Array,
-    k: jax.Array,
-    v: jax.Array,
-    attention_metadata: AttentionMetadata,
-    mesh: Mesh,
-    scale: float,
-    head_size: int,
-    num_heads: int,
-    num_kv_heads: int,
-) -> Tuple[jax.Array, jax.Array]:
-    del scale  # Unused for now, as the attention function applies a default scale.
- - # Get shapes from vllm - q_len, q_compute_dim = q.shape - k_len, k_compute_dim = k.shape - assert k.shape == v.shape - assert q_compute_dim == head_size * num_heads - assert k_compute_dim == head_size * num_kv_heads - - # Convert the shapes from vLLM's convetion to what the attention function expects - # bs, num_heads, q_len, head_size - q = q.reshape(q_len, num_heads, head_size) - # bs, num_kv_heads, k_len, head_size - k = k.reshape(k_len, num_kv_heads, head_size) - v = v.reshape(k_len, num_kv_heads, head_size) - - new_kv_cache, outputs = attention( - kv_cache, - q, - k, - v, - attention_metadata, - mesh, - ) - - # Convert the shape back to vLLM's convention - assert outputs.shape[0] == q_len - assert outputs.shape[1] == num_heads - assert outputs.shape[2] == head_size - outputs = outputs.reshape(q_len, q_compute_dim) - - return new_kv_cache, outputs - - -class JaxAttention(torch.nn.Module): - - def __init__( - self, - vllm_attn: VllmAttention, - mesh: Mesh, - ) -> None: - super().__init__() - - self.num_heads = vllm_attn.num_heads - self.head_size = vllm_attn.head_size - self.scale = vllm_attn.impl.scale - self.num_kv_heads = vllm_attn.num_kv_heads - self.layer_idx = extract_layer_index(vllm_attn.layer_name) - self.mesh = mesh - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - # For some alternate attention backends like MLA the attention output - # shape does not match the q shape, so we optionally let the model - # definition specify the output tensor shape. - output_shape: Optional[torch.Size] = None, - ) -> torch.Tensor: - vllm_model_wrapper_context = get_vllm_model_wrapper_context() - new_kv_cache, outputs = _jax_attn_func( - vllm_model_wrapper_context.kv_caches[self.layer_idx], jax_view(q), - jax_view(k), jax_view(v), - vllm_model_wrapper_context.attention_metadata, self.mesh, - self.scale, self.head_size, self.num_heads, self.num_kv_heads) - vllm_model_wrapper_context.kv_caches[self.layer_idx] = new_kv_cache - - return torch_view(outputs) diff --git a/tpu_commons/models/vllm/jax_linear_common.py b/tpu_commons/models/vllm/jax_linear_common.py index cfcb2a3e7..573bb8b9e 100644 --- a/tpu_commons/models/vllm/jax_linear_common.py +++ b/tpu_commons/models/vllm/jax_linear_common.py @@ -1,56 +1,37 @@ -import functools +from typing import Optional, Union import jax import jax.numpy as jnp +import torch from jax.experimental.shard_map import shard_map from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P +from torchax.interop import torch_view +from torchax.ops.mappings import t2j from tpu_commons.kernels.quantized_matmul.kernel import quantized_matmul_kernel -_quantized_matmul_kernel = functools.partial(quantized_matmul_kernel, - quantize_activation=True) - - -def forward_unqunatized(x: jax.Array, w: jax.Array, b: jax.Array): - output = jnp.einsum('mn,pn->mp', x, w) - if b is not None: - output = output + b - return output - - -def forward_w8a8_int8_col_parallel(x: jax.Array, w: jax.Array, b: jax.Array, - w_s: jax.Array, mesh: Mesh): - x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P())) - output = shard_map(_quantized_matmul_kernel, - mesh=mesh, - in_specs=(P(), P('model', None), P('model')), - out_specs=(P(None, 'model')), - check_rep=False)(x, w, w_s) - if b is not None: - output = output + b - return output - - -def forward_w8a8_int8_row_parallel(x: jax.Array, w: jax.Array, b: jax.Array, - w_s: jax.Array, mesh: Mesh, - reduce_results: bool): - x = jax.lax.with_sharding_constraint(x, - 
NamedSharding(mesh, P(None, 'model'))) - output = shard_map(_quantized_matmul_kernel, - mesh=mesh, - in_specs=(P(None, 'model'), P(None, 'model'), P()), - out_specs=(P(None, 'model')), - check_rep=False)(x, w, w_s) - if reduce_results: - output = shard_map(lambda x: jax.lax.psum(x, axis_name='model'), - mesh=mesh, - in_specs=P(None, 'model'), - out_specs=P(), - check_rep=False)(output) - if b is not None: - output = output + b - return output + +def sharded_quantized_matmul(x: jax.Array, w_q: jax.Array, w_s: jax.Array, + mesh: Mesh, weight_sharding: P): + out_axis, in_axis = weight_sharding + x_sharding = P(None, in_axis) + scale_sharding = P(out_axis, ) + out_sharding = P(None, out_axis) + + x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, x_sharding)) + + def wrapper(x, w_q, w_s): + output = quantized_matmul_kernel(x, w_q, w_s, quantize_activation=True) + if in_axis: + output = jax.lax.psum(output, axis_name=in_axis) + return output + + return shard_map(wrapper, + mesh=mesh, + in_specs=(x_sharding, weight_sharding, scale_sharding), + out_specs=(out_sharding), + check_rep=False)(x, w_q, w_s) def reorder_concatenated_tensor_for_sharding(concatenated_tensor: jax.Array, @@ -65,14 +46,12 @@ def reorder_concatenated_tensor_for_sharding(concatenated_tensor: jax.Array, The output is: AAABBCAAABBCAAABBCAAABBC In other words, it reorders the input tensor into 4 segements, with each segment corresponding to a shard and being AAABBC. - Args: concatenated_tensor: the tensor, concatenated on the dimension specified by `dim`. split_sizes: each individual tensor's size on the dimension specified by `dim`. n_shards: num of shards. dim: the dimension on which the concatenated_tensor is concatenated. """ - # Split the concatenated tensor into individual tensors. split_tensors = [] start_offset = 0 @@ -107,7 +86,6 @@ def slice_sharded_tensor_for_concatenation(sharded_tensor: jax.Array, C | C | C | C Shard0 Shard1 Shard2 Shard3 In other words, each individual tensor is a slice of the input tensor with the same sharding. - Args: sharded_tensor: the input tensor, sharded on the last dim. split_sizes: each individual tensor's size on the last dim. 
@@ -130,3 +108,36 @@ def slice_sharded_tensor_for_concatenation(sharded_tensor: jax.Array, start_offset = end_offset return split_tensors + + +def torch_to_jax_param( + tensor: torch.Tensor, + sharding: NamedSharding, + output_sizes: Optional[int], + n_shards: int, + fused: bool, +) -> Union[torch.nn.Parameter, torch.nn.ParameterList]: + if output_sizes is None: + output_sizes = [tensor.shape[0]] + + tensor = t2j(tensor, use_dlpack=False) + if fused: + tensor = reorder_concatenated_tensor_for_sharding( + tensor, output_sizes, n_shards, 0) + tensor = jax.device_put(tensor, sharding) + param = torch.nn.Parameter(torch_view(tensor), requires_grad=False) + else: + tensors = [] + start_offset = 0 + for size in output_sizes: + end_offset = start_offset + size + + tensor_split = tensor[start_offset:end_offset] + tensor_split = jax.device_put(tensor_split, sharding) + tensor_split = torch.nn.Parameter(torch_view(tensor_split), + requires_grad=False) + tensors.append(tensor_split) + + start_offset = end_offset + param = torch.nn.ParameterList(tensors) + return param diff --git a/tpu_commons/models/vllm/jax_merged_column_parallel_linear.py b/tpu_commons/models/vllm/jax_merged_column_parallel_linear.py deleted file mode 100644 index 1bdf84199..000000000 --- a/tpu_commons/models/vllm/jax_merged_column_parallel_linear.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch -from jax.sharding import Mesh -from vllm.model_executor.layers.linear import MergedColumnParallelLinear - -from tpu_commons.models.vllm.jax_merged_column_parallel_linear_core import \ - JaxMergedColumnParallelLinearCore - - -class JaxMergedColumnParallelLinear(JaxMergedColumnParallelLinearCore): - - def __init__(self, merged_col_parallel_linear: torch.nn.Module, mesh: Mesh, - fuse_matmuls: bool, enable_sequence_parallelism: bool): - assert isinstance(merged_col_parallel_linear, - MergedColumnParallelLinear) - super().__init__( - merged_col_parallel_linear, - mesh, - "JaxMergedColumnParallelLinear", - fuse_matmuls=fuse_matmuls, - enable_sequence_parallelism=enable_sequence_parallelism) diff --git a/tpu_commons/models/vllm/jax_merged_column_parallel_linear_core.py b/tpu_commons/models/vllm/jax_merged_column_parallel_linear_core.py deleted file mode 100644 index d842f10be..000000000 --- a/tpu_commons/models/vllm/jax_merged_column_parallel_linear_core.py +++ /dev/null @@ -1,248 +0,0 @@ -import jax -import jax.numpy as jnp -import torch -from jax.sharding import Mesh, NamedSharding, PartitionSpec -from torch.nn.parameter import Parameter -from torchax.interop import torch_view -from torchax.ops.mappings import t2j -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \ - CompressedTensorsLinearMethod -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import \ - CompressedTensorsW8A8Int8 - -from tpu_commons.models.vllm.jax_linear_common import ( - forward_unqunatized, forward_w8a8_int8_col_parallel, - reorder_concatenated_tensor_for_sharding, - slice_sharded_tensor_for_concatenation) -from tpu_commons.utils import TPU_SECOND_LAST_MINOR - -P = PartitionSpec - - -class JaxMergedColumnParallelLinearCore(torch.nn.Module): - """ A common class to implement Column Parallel Linear layer whose weight are merged from a list of smaller weight tensors, e.g. vLLM's MergedColumnParallelLinear and QKVParallelLinear layer. 
""" - - def __init__(self, vllm_col_par_linear: torch.nn.Module, mesh: Mesh, - name: str, fuse_matmuls: bool, enable_sequence_parallelism): - super().__init__() - - self.gather_output = vllm_col_par_linear.gather_output - self.skip_bias_add = vllm_col_par_linear.skip_bias_add - self.return_bias = vllm_col_par_linear.return_bias - self.output_sizes = vllm_col_par_linear.output_sizes - self.mesh = mesh - self.name = name - self.fuse_matmuls = fuse_matmuls - self.has_bias = vllm_col_par_linear.bias is not None - self.enable_sequence_parallelism = enable_sequence_parallelism - self.n_matmuls = len(self.output_sizes) - assert vllm_col_par_linear.tp_size == 1, ( - "The model has to be loaded with TP== 1 in order to run in Jax SPMD." - ) - - self.w8q8_int8_quant = False - if isinstance(vllm_col_par_linear.quant_method, - CompressedTensorsLinearMethod) and isinstance( - vllm_col_par_linear.scheme, - CompressedTensorsW8A8Int8): - self.w8q8_int8_quant = True - - if self.fuse_matmuls: - self._load_weights_from_merged_linear_fused(vllm_col_par_linear) - self._shard_weight_fused(mesh) - else: - self._load_weights_from_merged_linear_split(vllm_col_par_linear) - self._shard_weight_split(mesh) - - def _shard_weight_fused(self, mesh: Mesh): - self.weight.apply_jax_(jax.device_put, - NamedSharding(mesh, P('model', None))) - - if self.bias is not None: - self.bias.apply_jax_(jax.device_put, - NamedSharding(mesh, P('model'))) - - if self.w8q8_int8_quant: - self.weight_scale.apply_jax_(jax.device_put, - NamedSharding(mesh, P('model'))) - - def _shard_weight_split(self, mesh: Mesh): - # Shard all weights in the weight_list - for i in range(self.n_matmuls): - weight = getattr(self, f"weight_{i}") - weight.apply_jax_(jax.device_put, - NamedSharding(mesh, P('model', None))) - setattr(self, f"weight_{i}", weight) - - if self.has_bias: - for i in range(self.n_matmuls): - bias = getattr(self, f"bias_{i}") - bias.apply_jax_(jax.device_put, - NamedSharding(mesh, P('model', ))) - setattr(self, f"bias_{i}", bias) - - if self.w8q8_int8_quant: - for i in range(self.n_matmuls): - weight_scale = getattr(self, f"weight_scale_{i}") - weight_scale.apply_jax_(jax.device_put, - NamedSharding(mesh, P('model'))) - setattr(self, f"weight_scale_{i}", weight_scale) - - def _load_weights_from_merged_linear_fused( - self, vllm_col_par_linear: torch.nn.Module): - n_shards = self.mesh.shape['model'] - for _, output_size in enumerate(self.output_sizes): - assert output_size % n_shards == 0, "Each output size in MergedColumnParallelLinear must be a multiple of num chips in the 'model' axis." 
- - concat_weight = t2j(vllm_col_par_linear.weight.data, use_dlpack=False) - weight = reorder_concatenated_tensor_for_sharding(concat_weight, - self.output_sizes, - n_shards, - dim=0) - weight = Parameter(torch_view(weight), requires_grad=False) - self.register_parameter("weight", weight) - - if vllm_col_par_linear.bias is not None: - concat_bias = t2j(vllm_col_par_linear.bias.data, use_dlpack=False) - bias = reorder_concatenated_tensor_for_sharding(concat_bias, - self.output_sizes, - n_shards, - dim=0) - bias = Parameter(torch_view(bias), requires_grad=False) - self.register_parameter("bias", bias) - else: - self.register_parameter("bias", None) - - if self.w8q8_int8_quant: - assert self.weight.jax().dtype == jnp.int8 - concat_weight_scale = t2j(vllm_col_par_linear.weight_scale.data, - use_dlpack=False) - weight_scale = reorder_concatenated_tensor_for_sharding( - concat_weight_scale, self.output_sizes, n_shards, dim=0) - weight_scale = Parameter(torch_view(weight_scale), - requires_grad=False) - self.register_parameter("weight_scale", weight_scale) - else: - self.register_parameter("weight_scale", None) - - def _load_weights_from_merged_linear_split( - self, vllm_col_par_linear: torch.nn.Module): - output_sizes = vllm_col_par_linear.output_sizes - concat_weight = torch_view( - t2j(vllm_col_par_linear.weight.data, use_dlpack=False)) - concat_bias = None - if self.has_bias: - concat_bias = torch_view( - t2j(vllm_col_par_linear.bias.data, use_dlpack=False)) - if self.w8q8_int8_quant: - concat_weight_scale = torch_view( - t2j(vllm_col_par_linear.weight_scale.data, use_dlpack=False)) - start_offset = 0 - for i, size in enumerate(output_sizes): - weight = Parameter(concat_weight[start_offset:start_offset + - size].detach(), - requires_grad=False) - setattr(self, f"weight_{i}", weight) - - if concat_bias is not None: - bias = Parameter(concat_bias[start_offset:start_offset + - size].detach(), - requires_grad=False) - setattr(self, f"bias_{i}", bias) - else: - setattr(self, f"bias_{i}", None) - - if self.w8q8_int8_quant: - assert weight.jax().dtype == jnp.int8 - weight_scale = Parameter( - concat_weight_scale[start_offset:start_offset + - size].detach(), - requires_grad=False) - setattr(self, f"weight_scale_{i}", weight_scale) - else: - setattr(self, f"weight_scale_{i}", None) - - start_offset += size - - def forward_fused(self, input: torch.Tensor): - x = input.jax() - weight = self.weight.jax() - bias = None if (self.skip_bias_add - or self.bias is None) else self.bias.jax() - if self.w8q8_int8_quant: - weight_scale = self.weight_scale.jax( - ) if self.w8q8_int8_quant else None - output = forward_w8a8_int8_col_parallel(x, weight, bias, - weight_scale, self.mesh) - else: - output = forward_unqunatized(x, weight, bias) - - n_shards = self.mesh.shape['model'] - split_outputs = slice_sharded_tensor_for_concatenation( - output, self.output_sizes, n_shards) - if self.gather_output: - split_outputs = [ - jax.lax.with_sharding_constraint(t, - NamedSharding(self.mesh, P())) - for t in split_outputs - ] - output = torch_view(jnp.concatenate(split_outputs, axis=-1)) - - if not self.return_bias: - return output - - if self.skip_bias_add or self.bias is None: - output_bias = None - else: - split_biases = slice_sharded_tensor_for_concatenation( - self.bias, self.output_sizes, n_shards) - output_bias = torch_view(jnp.concatenate(split_biases, axis=-1)) - return output, output_bias - - def forward_split(self, input): - x = input.jax() - split_outputs = [] - for i in range(self.n_matmuls): - weight = 
getattr(self, f"weight_{i}").jax() - bias = getattr(self, f"bias_{i}") - bias = None if (self.skip_bias_add or bias is None) else bias.jax() - if self.w8q8_int8_quant: - weight_scale = getattr(self, f"weight_scale_{i}").jax() - output = forward_w8a8_int8_col_parallel( - x, weight, bias, weight_scale, self.mesh) - else: - output = forward_unqunatized(x, weight, bias) - split_outputs.append(output) - if self.gather_output: - split_outputs = [ - jax.lax.with_sharding_constraint(t, - NamedSharding(self.mesh, P())) - for t in split_outputs - ] - output = torch_view(jnp.concatenate(split_outputs, axis=-1)) - - if not self.return_bias: - return output - - if self.skip_bias_add or not self.has_bias: - output_bias = None - else: - split_biases = [ - getattr(self, f"bias_{i}").jax() - for i, _ in enumerate(self.output_sizes) - ] - output_bias = torch_view(jnp.concatenate(split_biases, axis=-1)) - return output, output_bias - - def forward(self, input: torch.Tensor): - with jax.named_scope(self.name): - if self.enable_sequence_parallelism: - token_num = input.shape[0] - # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR - if token_num // self.mesh.shape[ - 'model'] >= TPU_SECOND_LAST_MINOR: - input.shard_(NamedSharding(self.mesh, P('model', None))) - if self.fuse_matmuls: - return self.forward_fused(input) - else: - return self.forward_split(input) diff --git a/tpu_commons/models/vllm/jax_qkv_parallel_linear.py b/tpu_commons/models/vllm/jax_qkv_parallel_linear.py deleted file mode 100644 index b7ea9a6af..000000000 --- a/tpu_commons/models/vllm/jax_qkv_parallel_linear.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch -from jax.sharding import Mesh -from vllm.model_executor.layers.linear import QKVParallelLinear - -from tpu_commons.models.vllm.jax_merged_column_parallel_linear_core import \ - JaxMergedColumnParallelLinearCore - - -class JaxQKVParallelLinear(JaxMergedColumnParallelLinearCore): - - def __init__(self, qkv_linear: torch.nn.Module, mesh: Mesh, - fuse_matmuls: bool, enable_sequence_parallelism: bool): - assert isinstance(qkv_linear, QKVParallelLinear) - super().__init__( - qkv_linear, - mesh, - "JaxQKVParallelLinear", - fuse_matmuls=fuse_matmuls, - enable_sequence_parallelism=enable_sequence_parallelism) diff --git a/tpu_commons/models/vllm/jax_row_parallel_linear.py b/tpu_commons/models/vllm/jax_row_parallel_linear.py deleted file mode 100644 index 8ae23582c..000000000 --- a/tpu_commons/models/vllm/jax_row_parallel_linear.py +++ /dev/null @@ -1,108 +0,0 @@ -from typing import Optional - -import jax -import jax.numpy as jnp -import torch -from jax.sharding import Mesh, NamedSharding, PartitionSpec -from torch.nn.parameter import Parameter -from torchax.interop import torch_view -from torchax.ops.mappings import t2j -from vllm.model_executor.layers.linear import RowParallelLinear -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \ - CompressedTensorsLinearMethod -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import \ - CompressedTensorsW8A8Int8 - -from tpu_commons.models.vllm.jax_linear_common import ( - forward_unqunatized, forward_w8a8_int8_row_parallel) -from tpu_commons.utils import TPU_SECOND_LAST_MINOR - -P = PartitionSpec - - -class JaxRowParallelLinear(torch.nn.Module): - - def __init__(self, row_linear: torch.nn.Module, mesh: Mesh, - enable_sequence_parallelism: bool): - super().__init__() - assert isinstance(row_linear, RowParallelLinear) - - self.mesh = mesh - self.reduce_results 
= row_linear.reduce_results - self.skip_bias_add = row_linear.skip_bias_add - self.return_bias = row_linear.return_bias - self.enable_sequence_parallelism = enable_sequence_parallelism - - self.w8q8_int8_quant = False - if isinstance(row_linear.quant_method, - CompressedTensorsLinearMethod) and isinstance( - row_linear.scheme, CompressedTensorsW8A8Int8): - self.w8q8_int8_quant = True - - self.weight: Parameter - self.bias: Optional[Parameter] - self.weight_scale: Optional[Parameter] - - self._load_weights_from_vllm_layer(row_linear) - self._shard_weight(mesh) - - def _shard_weight(self, mesh: Mesh): - self.weight.apply_jax_(jax.device_put, - NamedSharding(mesh, P(None, 'model'))) - - if self.bias is not None: - self.bias.apply_jax_(jax.device_put, NamedSharding( - mesh, P())) # column parallel can't shard the bias - - if self.w8q8_int8_quant: - self.weight_scale.apply_jax_(jax.device_put, - NamedSharding(mesh, P())) - - def _load_weights_from_vllm_layer(self, row_linear: torch.nn.Module): - weight = Parameter(torch_view( - t2j(row_linear.weight.data, use_dlpack=False)), - requires_grad=False) - self.register_parameter("weight", weight) - - if row_linear.bias is not None: - bias = Parameter(torch_view( - t2j(row_linear.bias.data, use_dlpack=False)), - requires_grad=False) - self.register_parameter("bias", bias) - else: - self.register_parameter("bias", None) - - if self.w8q8_int8_quant: - assert weight.jax().dtype == jnp.int8 - weight_scale = Parameter(torch_view( - t2j(row_linear.weight_scale.data, use_dlpack=False)), - requires_grad=False) - self.register_parameter("weight_scale", weight_scale) - else: - self.register_parameter("weight_scale", None) - - def forward(self, input: torch.Tensor): - with jax.named_scope("JaxRowParallelLinear"): - x = input.jax() - weight = self.weight.jax() - bias = None if (self.skip_bias_add - or self.bias is None) else self.bias.jax() - if self.w8q8_int8_quant: - weight_scale = self.weight_scale.jax() - output = forward_w8a8_int8_row_parallel( - x, weight, bias, weight_scale, self.mesh, - self.reduce_results) - else: - output = forward_unqunatized(x, weight, bias) - output = torch_view(output) - - if not self.return_bias: - return output - output_bias = self.bias if self.skip_bias_add else None - if self.enable_sequence_parallelism: - token_num = input.shape[0] - # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR - if token_num // self.mesh.shape[ - 'model'] >= TPU_SECOND_LAST_MINOR: - output.shard_(NamedSharding(self.mesh, P('model', None))) - return output, output_bias diff --git a/tpu_commons/models/vllm/quantization/__init__.py b/tpu_commons/models/vllm/quantization/__init__.py new file mode 100644 index 000000000..1ed7a58bf --- /dev/null +++ b/tpu_commons/models/vllm/quantization/__init__.py @@ -0,0 +1,32 @@ +import copy + +from jax.sharding import Mesh +from vllm.config import VllmConfig +from vllm.model_executor.layers.quantization.base_config import \ + QuantizationConfig + +from tpu_commons.models.vllm.quantization.common import JaxCommonConfig +from tpu_commons.models.vllm.quantization.compressed_tensors.compressed_tensors import \ + JaxCompressedTensorsConfig # noqa: E501 +from tpu_commons.models.vllm.quantization.unquantized import \ + JaxUnquantizedConfig + + +def get_tpu_quantization_config(vllm_config: VllmConfig, + mesh: Mesh) -> QuantizationConfig: + model_config = copy.deepcopy(vllm_config.model_config) + # TODO(kyuyeunk): Add support for "tpu_int8". 
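+    # Maps the vLLM quantization method name (model_config.quantization) to the
+    # JAX/TPU config class that should handle it; `None` covers models that
+    # ship without any quantization config.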
+ method_to_config: dict[str, str] = { + None: JaxUnquantizedConfig, + "compressed-tensors": JaxCompressedTensorsConfig, + } + + if model_config.quantization not in method_to_config: + raise NotImplementedError + quant_config = method_to_config[model_config.quantization] + assert issubclass(quant_config, JaxCommonConfig) + quant_config.set_configs(vllm_config, mesh) + + model_config.quantization = quant_config.get_name() + return VllmConfig.get_quantization_config(model_config, + vllm_config.load_config) diff --git a/tpu_commons/models/vllm/quantization/common.py b/tpu_commons/models/vllm/quantization/common.py new file mode 100644 index 000000000..1200cc26d --- /dev/null +++ b/tpu_commons/models/vllm/quantization/common.py @@ -0,0 +1,91 @@ +import torchax +from jax.sharding import Mesh, PartitionSpec +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, LinearBase, MergedColumnParallelLinear, + QKVParallelLinear, ReplicatedLinear, RowParallelLinear) + +from tpu_commons.models.vllm.jax_merged_column_parallel_linear_fusion_assignments import \ + get_model_matmul_fusion_assignment +from tpu_commons.utils import TPU_SECOND_LAST_MINOR + +P = PartitionSpec + +logger = init_logger(__name__) + + +class JaxCommonLinearConfig: + + def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase): + assert isinstance(layer, LinearBase) + + self.mesh = mesh + self.output_sizes = [layer.output_size] + self.weight_sharding = P(None, None) + self.fuse_matmuls = True + self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism + self.input_sharding = None + self.output_sharding = None + + if isinstance(layer, RowParallelLinear): + self.weight_sharding = P(None, 'model') + if self.enable_sequence_parallelism: + self.output_sharding = P('model', None) + elif isinstance(layer, ColumnParallelLinear): + self.weight_sharding = P('model', None) + if self.enable_sequence_parallelism: + self.input_sharding = P('model', None) + + if isinstance(layer, MergedColumnParallelLinear) or isinstance( + layer, QKVParallelLinear): + self.output_sizes = layer.output_sizes + + self.fuse_matmuls = get_model_matmul_fusion_assignment( + vllm_config.model_config.model, + vllm_config.scheduler_config.max_num_batched_tokens, + vllm_config.parallel_config.tensor_parallel_size, + layer._get_name()) + elif isinstance(layer, ReplicatedLinear): + self.weight_sharding = P(None, None) + else: + logger.warning( + "Unsupported linear layer type of %s. 
Can potentially yield " + " bad performance.", type(layer)) + + self.bias_sharding = P(self.weight_sharding[0]) + self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1) + + def get_input_sharding(self, x: torchax.tensor.Tensor): + if self.enable_sequence_parallelism: + token_num = x.shape[0] + # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR + if token_num // self.mesh.shape['model'] >= TPU_SECOND_LAST_MINOR: + return self.input_sharding + else: + return None + return self.input_sharding + + def get_output_sharding(self, x: torchax.tensor.Tensor): + if self.enable_sequence_parallelism: + token_num = x.shape[0] + # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR + if token_num // self.mesh.shape['model'] >= TPU_SECOND_LAST_MINOR: + return self.output_sharding + else: + return None + return self.output_sharding + + +class JaxCommonConfig: + vllm_config: VllmConfig + mesh: Mesh + + @classmethod + def set_configs(cls, vllm_config: VllmConfig, mesh: Mesh): + cls.vllm_config = vllm_config + cls.mesh = mesh + + def get_linear_config(self, layer: LinearBase) -> JaxCommonLinearConfig: + assert isinstance(layer, LinearBase) + return JaxCommonLinearConfig(self.vllm_config, self.mesh, layer) diff --git a/tpu_commons/models/vllm/quantization/compressed_tensors/compressed_tensors.py b/tpu_commons/models/vllm/quantization/compressed_tensors/compressed_tensors.py new file mode 100644 index 000000000..61e9713e2 --- /dev/null +++ b/tpu_commons/models/vllm/quantization/compressed_tensors/compressed_tensors.py @@ -0,0 +1,109 @@ +from typing import Optional + +import torch +from jax.sharding import PartitionSpec +from vllm.attention.layer import Attention +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.quantization import \ + register_quantization_config +from vllm.model_executor.layers.quantization.base_config import \ + QuantizeMethodBase # noqa: E501 +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( + CompressedTensorsConfig, CompressedTensorsKVCacheMethod, + CompressedTensorsLinearMethod, CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + find_matched_target, is_activation_quantization_format, + should_ignore_layer) + +from tpu_commons.models.vllm.quantization.common import JaxCommonConfig +from tpu_commons.models.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \ + JaxCompressedTensorsW8A8Int8 +from tpu_commons.models.vllm.quantization.unquantized import \ + JaxUnquantizedConfig + +P = PartitionSpec +logger = init_logger(__name__) + + +@register_quantization_config("jax-compressed-tensors") +class JaxCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig): + + @classmethod + def get_name(cls) -> str: + return "jax-compressed-tensors" + + def get_scheme(self, + layer: torch.nn.Module, + layer_name: Optional[str] = None + ) -> Optional["CompressedTensorsScheme"]: + """ + compressed-tensors supports non uniform in the following way: + + targets of config_groups: There can be N config_groups which each + have a quantization scheme. Each config_group has a list of targets + which can be a full layer_name, a regex for a layer_name, or + an nn.Module name. 
+ + Detect whether a layer_name is found in any target and + use the quantization scheme corresponding to the matched target + to select the CompressedTensorsScheme used for inference. + """ + + # Will be empty for models with only sparsity + weight_quant = input_quant = None + if self.target_scheme_map: + matched_target = find_matched_target( + layer_name=layer_name, + module=layer, + targets=self.target_scheme_map.keys(), + fused_mapping=self.packed_modules_mapping) + + scheme_dict = self.target_scheme_map[matched_target] + weight_quant = scheme_dict.get("weights") + input_quant = scheme_dict.get("input_activations") + format = scheme_dict.get("format") + + if weight_quant is None: + logger.warning_once("Acceleration for non-quantized schemes is " + "not supported by Compressed Tensors. " + "Falling back to UnquantizedLinearMethod") + return None + + # TODO(kyuyeunk): Add support for different act_quant_format + act_quant_format = is_activation_quantization_format( # noqa: F841 + format + ) if format is not None else is_activation_quantization_format( + self.quant_format) + + linear_config = self.get_linear_config(layer) + # TODO(kyuyeunk): Add support for FP8 w8a8. + if self._is_dynamic_token_w8a8(weight_quant, input_quant): + return JaxCompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=False, + input_symmetric=input_quant.symmetric, + jax_config=linear_config, + ) + raise NotImplementedError( + "No compressed-tensors compatible scheme was found.") + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional[QuantizeMethodBase]: + if should_ignore_layer(prefix, + ignore=self.ignore, + fused_mapping=self.packed_modules_mapping): + return JaxUnquantizedConfig.get_quant_method(self, layer, prefix) + if isinstance(layer, LinearBase): + scheme = self.get_scheme(layer=layer, layer_name=prefix) + if scheme is None: + return JaxUnquantizedConfig.get_quant_method( + self, layer, prefix) + layer.scheme = scheme + return CompressedTensorsLinearMethod(self) + if isinstance(layer, Attention): + return CompressedTensorsKVCacheMethod(self) + return None diff --git a/tpu_commons/models/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/tpu_commons/models/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py new file mode 100644 index 000000000..e78f47455 --- /dev/null +++ b/tpu_commons/models/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -0,0 +1,136 @@ +from typing import Optional + +import jax +import jax.numpy as jnp +import torch +from compressed_tensors.quantization import QuantizationStrategy +from jax.sharding import NamedSharding, PartitionSpec +from torchax.interop import torch_view +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \ + CompressedTensorsW8A8Int8 +from vllm.model_executor.layers.quantization.utils.w8a8_utils import \ + convert_to_channelwise + +from tpu_commons.models.vllm.jax_linear_common import ( + sharded_quantized_matmul, slice_sharded_tensor_for_concatenation, + torch_to_jax_param) +from tpu_commons.models.vllm.quantization.common import JaxCommonLinearConfig + +P = PartitionSpec +logger = init_logger(__name__) + + +class JaxCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8): + + def __init__(self, strategy: str, is_static_input_scheme: bool, + input_symmetric: bool, jax_config: JaxCommonLinearConfig): + super().__init__(strategy, 
is_static_input_scheme, input_symmetric)
+
+        self.jax_config = jax_config
+        self.is_channelwise = (self.strategy == QuantizationStrategy.CHANNEL)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        weight = torch_to_jax_param(
+            layer.weight,
+            NamedSharding(self.jax_config.mesh,
+                          self.jax_config.weight_sharding),
+            self.jax_config.output_sizes,
+            self.jax_config.n_shards,
+            self.jax_config.fuse_matmuls,
+        )
+        delattr(layer, 'weight')
+        layer.weight = weight
+
+        weight_scale = layer.weight_scale
+        is_fused_module = len(layer.logical_widths) > 1
+        if is_fused_module and not self.is_channelwise:
+            weight_scale = convert_to_channelwise(weight_scale,
+                                                  layer.logical_widths)
+        weight_scale = weight_scale.squeeze(-1)
+
+        weight_scale = torch_to_jax_param(
+            weight_scale,
+            NamedSharding(self.jax_config.mesh, self.jax_config.bias_sharding),
+            self.jax_config.output_sizes,
+            self.jax_config.n_shards,
+            self.jax_config.fuse_matmuls,
+        )
+        delattr(layer, 'weight_scale')
+        layer.weight_scale = weight_scale
+
+        if layer.bias is not None and not layer.skip_bias_add:
+            if layer.return_bias:
+                logger.warning_once("Bias might return incorrect value.")
+
+            bias = torch_to_jax_param(
+                layer.bias,
+                NamedSharding(self.jax_config.mesh,
+                              self.jax_config.bias_sharding),
+                self.jax_config.output_sizes,
+                self.jax_config.n_shards,
+                self.jax_config.fuse_matmuls,
+            )
+            delattr(layer, 'bias')
+            layer.bias = bias
+
+        # TODO(kyuyeunk): Support static range input quantization.
+        assert getattr(layer, 'input_scale', None) is None
+        assert getattr(layer, 'input_zero_point', None) is None
+        assert getattr(layer, 'azp_adj', None) is None
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        with jax.named_scope(layer._get_name()):
+            if self.jax_config.fuse_matmuls:
+                out = self._apply_fused(layer, x, bias)
+            else:
+                out = self._apply_split(layer, x, bias)
+
+        return out
+
+    def _apply_fused(self, layer: torch.nn.Module, x: torch.Tensor,
+                     bias: Optional[torch.Tensor]) -> torch.Tensor:
+        x_jax = x.jax()
+        weight_jax = layer.weight.jax()
+        weight_scale_jax = layer.weight_scale.jax()
+
+        outs = sharded_quantized_matmul(
+            x_jax,
+            weight_jax,
+            weight_scale_jax,
+            self.jax_config.mesh,
+            self.jax_config.weight_sharding,
+        )
+        if bias is not None and not layer.skip_bias_add:
+            outs += bias.jax()
+
+        outs = slice_sharded_tensor_for_concatenation(
+            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+        out = jnp.concatenate(outs, axis=-1)
+        return torch_view(out)
+
+    def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
+                     bias: Optional[torch.Tensor]) -> torch.Tensor:
+        assert isinstance(layer.weight, torch.nn.ParameterList)
+
+        x_jax = x.jax()
+        outs = []
+        for i, (weight, weight_scale) in enumerate(
+                zip(layer.weight, layer.weight_scale)):
+            weight_jax = weight.jax()
+            weight_scale_jax = weight_scale.jax()
+
+            out = sharded_quantized_matmul(
+                x_jax,
+                weight_jax,
+                weight_scale_jax,
+                self.jax_config.mesh,
+                self.jax_config.weight_sharding,
+            )
+            if bias is not None and not layer.skip_bias_add:
+                out += bias[i].jax()
+
+            outs.append(out)
+        out = jnp.concatenate(outs, axis=-1)
+        return torch_view(out)
diff --git a/tpu_commons/models/vllm/quantization/unquantized.py b/tpu_commons/models/vllm/quantization/unquantized.py
new file mode 100644
index 000000000..6560d3630
--- /dev/null
+++ b/tpu_commons/models/vllm/quantization/unquantized.py
@@ -0,0 +1,146 @@
+from typing import Any, Optional
+
+import jax
+import 
jax.numpy as jnp +import torch +from jax.sharding import NamedSharding, PartitionSpec +from torchax.interop import torch_view +from vllm.attention.layer import Attention +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, UnquantizedFusedMoEMethod) +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import \ + register_quantization_config +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) + +from tpu_commons.models.vllm.jax_linear_common import ( + slice_sharded_tensor_for_concatenation, torch_to_jax_param) +from tpu_commons.models.vllm.quantization.common import (JaxCommonConfig, + JaxCommonLinearConfig) + +P = PartitionSpec +logger = init_logger(__name__) + + +@register_quantization_config("jax-unquantized") +class JaxUnquantizedConfig(QuantizationConfig, JaxCommonConfig): + + @classmethod + def get_name(cls) -> str: + return "jax-unquantized" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.float32, torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 0 # Always supported + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] # No extra configs required. + + @classmethod + def from_config(cls, _: dict[str, Any]) -> "JaxUnquantizedConfig": + return cls() + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional[QuantizeMethodBase]: + if isinstance(layer, LinearBase): + linear_config = self.get_linear_config(layer) + return JaxUnquantizedLinearMethod(linear_config) + if isinstance(layer, FusedMoE): + return UnquantizedFusedMoEMethod(layer.moe_config) + if isinstance(layer, Attention): + return None + return None + + +class JaxUnquantizedLinearMethod(UnquantizedLinearMethod): + + def __init__(self, jax_config: JaxCommonLinearConfig): + self.jax_config = jax_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight = torch_to_jax_param( + layer.weight, + NamedSharding(self.jax_config.mesh, + self.jax_config.weight_sharding), + self.jax_config.output_sizes, + self.jax_config.n_shards, + self.jax_config.fuse_matmuls, + ) + delattr(layer, 'weight') + layer.weight = weight + + if layer.bias is not None and not layer.skip_bias_add: + if layer.return_bias: + logger.warning_once("Bias might return incorrect value.") + + bias = torch_to_jax_param( + layer.bias, + NamedSharding(self.jax_config.mesh, + self.jax_config.bias_sharding), + self.jax_config.output_sizes, + self.jax_config.n_shards, + self.jax_config.fuse_matmuls, + ) + delattr(layer, 'bias') + layer.bias = bias + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + with jax.named_scope(layer._get_name()): + if in_sharding := self.jax_config.get_input_sharding(x): + x.shard_(NamedSharding(self.jax_config.mesh, in_sharding)) + + if self.jax_config.fuse_matmuls: + out = self._apply_fused(layer, x, bias) + else: + out = self._apply_split(layer, x, bias) + + if out_sharding := self.jax_config.get_output_sharding(out): + out.shard_(NamedSharding(self.jax_config.mesh, out_sharding)) + + return out + + def _apply_fused(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + x_jax = x.jax() + weight_jax = layer.weight.jax() + + outs = jnp.einsum('mn,pn->mp', x_jax, weight_jax) + 
if bias is not None and not layer.skip_bias_add: + outs += bias.jax() + + outs = slice_sharded_tensor_for_concatenation( + outs, self.jax_config.output_sizes, self.jax_config.n_shards) + out = jnp.concatenate(outs, axis=-1) + return torch_view(out) + + def _apply_split(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + assert isinstance(layer.weight, torch.nn.ParameterList) + + x_jax = x.jax() + outs = [] + for i, weight in enumerate(layer.weight): + weight_jax = weight.jax() + + out = jnp.einsum('mn,pn->mp', x_jax, weight_jax) + if bias is not None and not layer.skip_bias_add: + out += bias[i].jax() + + outs.append(out) + out = jnp.concatenate(outs, axis=-1) + return torch_view(out) diff --git a/tpu_commons/models/vllm/sharding.py b/tpu_commons/models/vllm/sharding.py index 176b1cf18..8f2f5b07b 100644 --- a/tpu_commons/models/vllm/sharding.py +++ b/tpu_commons/models/vllm/sharding.py @@ -5,90 +5,20 @@ import torch import torchax from jax.sharding import Mesh, NamedSharding, PartitionSpec -from torch.nn.parameter import Parameter from torch.utils import _pytree as pytree -from torchax.interop import extract_all_buffers, torch_view +from torchax.interop import torch_view from torchax.ops.mappings import t2j -from vllm.attention import Attention as VllmAttention from vllm.config import VllmConfig from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import \ - UnquantizedLinearMethod # yapf: disable -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) from tpu_commons.logger import init_logger -from tpu_commons.models.vllm.jax_attention import JaxAttention from tpu_commons.models.vllm.jax_fused_moe import JaxFusedMoE -from tpu_commons.models.vllm.jax_merged_column_parallel_linear import \ - JaxMergedColumnParallelLinear -from tpu_commons.models.vllm.jax_merged_column_parallel_linear_fusion_assignments import \ - get_model_matmul_fusion_assignment -from tpu_commons.models.vllm.jax_qkv_parallel_linear import \ - JaxQKVParallelLinear -from tpu_commons.models.vllm.jax_row_parallel_linear import \ - JaxRowParallelLinear P = PartitionSpec logger = init_logger(__name__) -def shard_attention(layer: torch.nn.Module, mesh: Mesh, - vllm_config: VllmConfig): - return JaxAttention(layer, mesh) - - -def shard_qkv_parallel_linear(layer: torch.nn.Module, mesh: Mesh, - vllm_config: VllmConfig): - assert isinstance(layer, QKVParallelLinear) - jax_layer = JaxQKVParallelLinear( - layer, - mesh, - shard_qkv_parallel_linear.fuse_matmuls, - enable_sequence_parallelism=vllm_config.compilation_config.pass_config. - enable_sequence_parallelism) - return jax_layer - - -def shard_merged_column_parallel_linear(layer: torch.nn.Module, mesh: Mesh, - vllm_config: VllmConfig): - assert isinstance(layer, MergedColumnParallelLinear) - jax_layer = JaxMergedColumnParallelLinear( - layer, - mesh, - shard_merged_column_parallel_linear.fuse_matmuls, - enable_sequence_parallelism=vllm_config.compilation_config.pass_config. 
- enable_sequence_parallelism) - return jax_layer - - -def shard_column_parallel_linear(layer: torch.nn.Module, mesh: Mesh, - vllm_config: VllmConfig): - assert isinstance(layer, ColumnParallelLinear) - if not isinstance(layer.quant_method, UnquantizedLinearMethod): - raise ValueError( - "tpu_commons torchax ColumnParallelLinear doesn't support quantization" - ) - w = Parameter(torch_view(t2j(layer.weight)), requires_grad=False) - layer.weight = w.apply_jax_(jax.device_put, - NamedSharding(mesh, P('model', None))) - return layer - - -def shard_row_parallel_linear(layer: torch.nn.Module, mesh: Mesh, - vllm_config: VllmConfig): - assert isinstance(layer, RowParallelLinear) - jax_layer = JaxRowParallelLinear( - layer, - mesh, - enable_sequence_parallelism=vllm_config.compilation_config.pass_config. - enable_sequence_parallelism) - return jax_layer - - def shard_fused_moe(layer: torch.nn.Module, mesh: Mesh, vllm_config: VllmConfig): assert isinstance(layer, FusedMoE) @@ -97,11 +27,7 @@ def shard_fused_moe(layer: torch.nn.Module, mesh: Mesh, MODULE_TYPE_TO_WRAPPING_FUNC = { - VllmAttention: shard_attention, - QKVParallelLinear: shard_qkv_parallel_linear, - MergedColumnParallelLinear: shard_merged_column_parallel_linear, - ColumnParallelLinear: shard_column_parallel_linear, - RowParallelLinear: shard_row_parallel_linear, + # TODO(kyuyeunk): Refactor this layer to use vLLM APIs. FusedMoE: shard_fused_moe, } @@ -154,33 +80,71 @@ def _move_to_tpu_replicated(x): return torch_view(x).apply_jax_(jax.device_put, NamedSharding(mesh, P())) - tp_size = vllm_config.parallel_config.tensor_parallel_size - shard_qkv_parallel_linear.fuse_matmuls = get_model_matmul_fusion_assignment( - vllm_config.model_config.model, - vllm_config.scheduler_config.max_num_batched_tokens, tp_size, - "QKVParallelLinear") - shard_merged_column_parallel_linear.fuse_matmuls = get_model_matmul_fusion_assignment( - vllm_config.model_config.model, - vllm_config.scheduler_config.max_num_batched_tokens, tp_size, - "MergedColumnParallelLinear") - with jax.default_device(jax.devices("cpu")[0]), torchax.default_env(): shard_parallel_layers_to_tpu(model, mesh, vllm_config) # For other weight tensors, repliate them on all the TPU chips. 
- params, buffers = extract_all_buffers(model) + params, buffers, variables = extract_all_buffers(model) fmt_size = functools.partial(humanize.naturalsize, binary=True) - for qual_name, x in {**params, **buffers}.items(): + for qual_name, x in {**params, **buffers, **variables}.items(): if _is_unmoved_tensor(x): tensor_size = fmt_size(x.nbytes) logger.debug( f"{qual_name=} is not sharded, {tensor_size=}, {x.shape=}, {x.dtype=}" ) - params, buffers = pytree.tree_map_only(_is_unmoved_tensor, - _move_to_tpu_replicated, - (params, buffers)) + params, buffers, variables = pytree.tree_map_only( + _is_unmoved_tensor, _move_to_tpu_replicated, + (params, buffers, variables)) + set_all_buffers(model, {}, {}, variables) params_and_buffers = {**params, **buffers} return params_and_buffers + + +def extract_all_buffers(m: torch.nn.Module): + params = {} + buffers = {} + variables = {} + + def extract_one(module, prefix): + for k in dir(module): + v = getattr(module, k, None) + if v is None: + continue + + qual_name = prefix + k + if isinstance(v, torch.nn.Parameter): + params[qual_name] = v + elif isinstance(v, torch.nn.ParameterList): + for i, param in enumerate(v): + params[qual_name + f'.{i}'] = param + elif k in module._buffers: + buffers[qual_name] = v + elif isinstance(v, torch.Tensor): + variables[qual_name] = v + + for name, child in module.named_children(): + extract_one(child, prefix + name + '.') + + extract_one(m, '') + return params, buffers, variables + + +def set_all_buffers(m, params, buffers, variables): + + def set_one(module, prefix): + for k in dir(module): + qual_name = prefix + k + if (potential_v := buffers.get(qual_name)) is not None or ( + potential_v := variables.get(qual_name)) is not None: + setattr(module, k, potential_v) + elif (potential_v := params.get(qual_name)) is not None: + # print(k, potential_v) + # setattr(module, k, torch.nn.Parameter(potential_v)) + module.register_parameter(k, potential_v) + for name, child in module.named_children(): + set_one(child, prefix + name + '.') + + set_one(m, '') diff --git a/tpu_commons/models/vllm/vllm_model_wrapper.py b/tpu_commons/models/vllm/vllm_model_wrapper.py index ed74ac58b..926fb4e93 100644 --- a/tpu_commons/models/vllm/vllm_model_wrapper.py +++ b/tpu_commons/models/vllm/vllm_model_wrapper.py @@ -17,11 +17,13 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, init_distributed_environment) +from vllm.forward_context import set_forward_context from vllm.model_executor.model_loader import get_model as vllm_get_model from vllm.sequence import IntermediateTensors from tpu_commons.logger import init_logger from tpu_commons.models.jax.attention_metadata import AttentionMetadata +from tpu_commons.models.vllm.quantization import get_tpu_quantization_config from tpu_commons.models.vllm.sharding import shard_model_to_tpu from tpu_commons.models.vllm.vllm_model_wrapper_context import ( get_vllm_model_wrapper_context, set_vllm_model_wrapper_context) @@ -83,6 +85,9 @@ def __init__(self, vllm_config: VllmConfig, rng: PRNGKey, mesh: Mesh): self.rng = rng self.mesh = mesh + self.vllm_config.quant_config = get_tpu_quantization_config( + self.vllm_config, self.mesh) + def load_weights(self): # Initialize the vLLM distribution layer as a single chip environment, # we'll swap the model's parallel modules with TPU SPMD equivalents. 
@@ -149,14 +154,14 @@ def step_fun( params_and_buffers, # this has been wrapped into a torchax TorchValue kv_caches: List[jax.Array], input_ids: jax.Array, - attention_metadata: AttentionMetadata, + attn_metadata: AttentionMetadata, *args, ) -> Tuple[List[jax.Array], jax.Array]: with torchax.default_env(), set_vllm_model_wrapper_context( - kv_caches=kv_caches, - attention_metadata=attention_metadata, - ): + kv_caches=kv_caches, mesh=self.mesh), set_forward_context( + attn_metadata=attn_metadata, + vllm_config=self.vllm_config): # We need to wrap args from jax land into TorchValue with # torch_view in order to call the Torch function. hidden_states = torch.func.functional_call( @@ -164,8 +169,7 @@ def step_fun( torch_view(params_and_buffers), kwargs={ "input_ids": torch_view(input_ids), - "positions": - torch_view(attention_metadata.input_positions), + "positions": torch_view(attn_metadata.input_positions), "intermediate_tensors": None, "inputs_embeds": None, }, @@ -192,7 +196,8 @@ def compute_logits_func( params_and_buffers: Any, hidden_states: jax.Array, ) -> jax.Array: - with torchax.default_env(): + with torchax.default_env(), set_vllm_model_wrapper_context( + kv_caches=None, mesh=self.mesh): logits = torch.func.functional_call( self.model, torch_view(params_and_buffers), diff --git a/tpu_commons/models/vllm/vllm_model_wrapper_context.py b/tpu_commons/models/vllm/vllm_model_wrapper_context.py index 9f1df8ea6..486bc50ab 100644 --- a/tpu_commons/models/vllm/vllm_model_wrapper_context.py +++ b/tpu_commons/models/vllm/vllm_model_wrapper_context.py @@ -3,8 +3,7 @@ from typing import List, Optional, Tuple import jax - -from tpu_commons.models.jax.attention_metadata import AttentionMetadata +from jax.sharding import Mesh KVCache = Tuple[jax.Array, jax.Array] @@ -12,7 +11,7 @@ @dataclass class VllmModelWrapperContext: kv_caches: List[KVCache] - attention_metadata: AttentionMetadata + mesh: Mesh _vllm_model_wrapper_context: Optional[VllmModelWrapperContext] = None @@ -30,14 +29,12 @@ def get_vllm_model_wrapper_context() -> VllmModelWrapperContext: def set_vllm_model_wrapper_context( *, kv_caches: List[KVCache], - attention_metadata: AttentionMetadata, + mesh: Mesh, ): global _vllm_model_wrapper_context prev_context = _vllm_model_wrapper_context - _vllm_model_wrapper_context = VllmModelWrapperContext( - kv_caches=kv_caches, - attention_metadata=attention_metadata, - ) + _vllm_model_wrapper_context = VllmModelWrapperContext(kv_caches=kv_caches, + mesh=mesh) try: yield diff --git a/tpu_commons/platforms/tpu_jax.py b/tpu_commons/platforms/tpu_jax.py index a739c8c31..f56a59c80 100644 --- a/tpu_commons/platforms/tpu_jax.py +++ b/tpu_commons/platforms/tpu_jax.py @@ -60,7 +60,7 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, if use_v1: logger.info("Using Pallas V1 backend.") - return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + return "tpu_commons.attention.backends.pallas_torchax.PallasAttentionBackend" else: logger.info("Using Pallas backend.") return "vllm.attention.backends.pallas.PallasAttentionBackend"