diff --git a/.buildkite/features/runai_model_streamer_loader.yml b/.buildkite/features/runai_model_streamer_loader.yml
new file mode 100644
index 000000000..3c9cf9c93
--- /dev/null
+++ b/.buildkite/features/runai_model_streamer_loader.yml
@@ -0,0 +1,26 @@
+# runai_model_streamer_loader
+# The RunAI Model Streamer is a high-performance model loader that serves as an
+# alternative to the default Hugging Face loader. Instead of downloading a model
+# to local disk first, it streams the weights from object storage (like GCS)
+# directly into memory, which is significantly faster than the traditional
+# disk-based loading method.
+steps:
+  - label: "Correctness tests for runai_model_streamer_loader"
+    key: "runai_model_streamer_loader_CorrectnessTest"
+    soft_fail: true
+    agents:
+      queue: tpu_v6e_queue
+    commands:
+      - .buildkite/scripts/run_in_docker.sh python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_runai_model_streamer_loader.py::test_correctness
+  - label: "Record correctness test result for runai_model_streamer_loader"
+    key: "record_runai_model_streamer_loader_CorrectnessTest"
+    depends_on: "runai_model_streamer_loader_CorrectnessTest"
+    env:
+      CI_TARGET: "runai_model_streamer_loader"
+      CI_STAGE: "CorrectnessTest"
+      CI_CATEGORY: "feature support matrix"
+    agents:
+      queue: cpu
+    commands:
+      - |
+        .buildkite/scripts/record_step_result.sh runai_model_streamer_loader_CorrectnessTest
diff --git a/requirements.txt b/requirements.txt
index 393869b29..55b554c07 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,3 +15,4 @@ torchvision==0.24.0
 pathwaysutils
 parameterized
 numba==0.62.1
+runai-model-streamer[s3,gcs]==0.15.0
diff --git a/tests/e2e/test_runai_model_streamer_loader.py b/tests/e2e/test_runai_model_streamer_loader.py
new file mode 100644
index 000000000..23e8b54fa
--- /dev/null
+++ b/tests/e2e/test_runai_model_streamer_loader.py
@@ -0,0 +1,90 @@
+# This file contains end-to-end tests for the RunAI Model Streamer loader.
+#
+# The RunAI Model Streamer is a high-performance model loader that serves as an
+# alternative to the default Hugging Face loader. Instead of downloading a model
+# to local disk first, it streams the weights from object storage (like GCS)
+# directly into memory, which is significantly faster than the traditional
+# disk-based loading method.
+#
+# The tests in this file verify that loading model weights through the
+# streamer produces the same results as loading the same model through the
+# standard Hugging Face loader. This ensures the correctness of the streamer
+# integration.
+#
+# The tests are performed by:
+# 1. Loading a model from Google Cloud Storage using the `runai_streamer` format.
+# 2. Generating output with this model.
+# 3. Loading the same model from Hugging Face using the default loader.
+# 4. Generating output with this second model.
+# 5. Asserting that the outputs from both models are identical.
+
+from __future__ import annotations
+
+import time
+
+import pytest
+from vllm import LLM, SamplingParams
+
+
+@pytest.fixture
+def sampling_config():
+    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=True)
+
+
+@pytest.fixture
+# TODO(amacaskill): Replace with a GKE-owned GCS bucket.
+def gcs_model_name():
+    return "gs://vertex-model-garden-public-us/llama3/llama3-8b-hf"
+
+
+@pytest.fixture
+def hf_model_name():
+    return "meta-llama/Meta-Llama-3-8B"
+
+
+@pytest.fixture
+def prompt():
+    return "Hello, my name is"
+
+
+def test_correctness(
+    sampling_config: SamplingParams,
+    gcs_model_name: str,
+    hf_model_name: str,
+    prompt: str,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    '''
+    Compare the outputs of a model loaded from GCS via runai_model_streamer
+    with those of the same model loaded from Hugging Face. The outputs should
+    be identical. The test uses tensor_parallel_size=1; the model is 16GB and
+    v6e has 32GB of HBM, so it fits on a single chip.
+    '''
+    # Set env variables so that runai_model_streamer uses anonymous GCS access.
+    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "fake-project")
+    monkeypatch.setenv("RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS", "true")
+    monkeypatch.setenv("CLOUD_STORAGE_EMULATOR_ENDPOINT",
+                       "https://storage.googleapis.com")
+    gcs_llm = LLM(model=gcs_model_name,
+                  load_format="runai_streamer",
+                  max_model_len=128,
+                  max_num_seqs=16,
+                  max_num_batched_tokens=256)
+    gcs_outputs = gcs_llm.generate([prompt], sampling_config)
+    gcs_output_text = gcs_outputs[0].outputs[0].text
+    del gcs_llm
+    time.sleep(10)  # Wait for TPUs to be released.
+
+    # Test with the Hugging Face model.
+    hf_llm = LLM(model=hf_model_name,
+                 max_model_len=128,
+                 max_num_seqs=16,
+                 max_num_batched_tokens=256)
+    hf_outputs = hf_llm.generate([prompt], sampling_config)
+    hf_output_text = hf_outputs[0].outputs[0].text
+    del hf_llm
+    time.sleep(10)  # Wait for TPUs to be released.
+
+    assert gcs_output_text == hf_output_text, (
+        f"Outputs do not match! "
+        f"GCS output: {gcs_output_text}, HF output: {hf_output_text}")
diff --git a/tests/models/common/test_model_loader.py b/tests/models/common/test_model_loader.py
index c667e6ba4..692b04137 100644
--- a/tests/models/common/test_model_loader.py
+++ b/tests/models/common/test_model_loader.py
@@ -55,6 +55,7 @@ def vllm_config() -> MagicMock:
     mock_config.model_config.dtype = jnp.bfloat16
     mock_config.load_config = MagicMock()
     mock_config.load_config.download_dir = None
+    mock_config.load_config.load_format = "auto"
     mock_config.additional_config = {}
     mock_config.cache_config = MagicMock(cache_dtype="auto")
     return mock_config
diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py
index cc97aef80..4acda58df 100644
--- a/tpu_inference/models/common/model_loader.py
+++ b/tpu_inference/models/common/model_loader.py
@@ -8,6 +8,9 @@ from torchax.ops.mappings import j2t_dtype
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.runai_streamer_loader import \
+    RunaiModelStreamerLoader
 from vllm.utils.func_utils import supports_kw
 
 from tpu_inference import envs
@@ -177,7 +180,23 @@
         # the model creation again, otherwise the model forward will have
         # non-trivial overhead in PjitFunction.
         with mesh:
-            model.load_weights(rng)
+            loader = get_model_loader(vllm_config.load_config)
+            if isinstance(loader, RunaiModelStreamerLoader):
+                model_weights = vllm_config.model_config.model
+                if hasattr(vllm_config.model_config, "model_weights"):
+                    model_weights = vllm_config.model_config.model_weights
+                weights_iterator = loader._get_weights_iterator(
+                    model_weights, vllm_config.model_config.revision)
+                # Attach the weights iterator to the config at runtime so we
+                # do not have to change every model's load_weights signature.
+                # This also avoids a TypeError when RunaiModelStreamerLoader
+                # is used with a flax_nnx model whose load_weights does not
+                # accept a weights_iterator keyword argument.
+                vllm_config.model_config.model_weights_iterator = weights_iterator
+                model.load_weights(rng)
+                del vllm_config.model_config.model_weights_iterator
+            else:
+                model.load_weights(rng)
         jit_model = create_jit_model(
             model,
             use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
diff --git a/tpu_inference/models/jax/utils/weight_utils.py b/tpu_inference/models/jax/utils/weight_utils.py
index ec8ecaa4f..4afab3225 100644
--- a/tpu_inference/models/jax/utils/weight_utils.py
+++ b/tpu_inference/models/jax/utils/weight_utils.py
@@ -13,10 +13,12 @@
 import jax
 import jax.numpy as jnp
 import torch
+import torchax
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
+from vllm.config import VllmConfig
 
 from tpu_inference import envs, utils
 from tpu_inference.logger import init_logger
@@ -265,15 +267,15 @@ def get_default_maps(model_config, mesh: Mesh,
                      bias_pad_map=bias_pad_keys)
 
 
-def _load_hf_weights_on_thread(vllm_config,
-                               params: nnx.State,
-                               metadata_map: MetadataMap,
-                               mesh: Mesh,
-                               weights_file: str,
-                               filter_regex: str | None = None,
-                               keep_original_dtype_keys_regex: list[str]
-                               | None = None,
-                               exclude_regex: list[str] | None = None):
+def _load_and_shard_weight(vllm_config,
+                           params: nnx.State,
+                           shardings: Any,
+                           metadata_map: MetadataMap,
+                           mesh: Mesh,
+                           hf_key: str,
+                           hf_weight: jax.Array,
+                           keep_original_dtype_keys_regex: list[str]
+                           | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
@@ -290,6 +292,119 @@
     head_dim = utils.get_padded_head_dim(head_dim_original)
     head_dim_pad = head_dim - head_dim_original
 
+    # Check if the key should retain its original dtype
+    keep_original_dtype = False
+    if keep_original_dtype_keys_regex:
+        for pattern in keep_original_dtype_keys_regex:
+            if re.match(pattern, hf_key):
+                keep_original_dtype = True
+                break
+
+    # Converting to config's dtype
+    if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
+        logger.warning(
+            f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
+        )
+        hf_weight = hf_weight.astype(model_config.dtype)
+
+    if hf_key.endswith(".weight"):
+        hf_key = hf_key.removesuffix(".weight")
+
+    # Find the corresponding model key using the HF key
+    if "layers" in hf_key:
+        layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
+        layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
+        model_key = name_map[layer_key]
+        model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
+    elif "blocks" in hf_key:
+        layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
+        layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
+        model_key = name_map[layer_key]
model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key) + else: + if hf_key not in name_map and hf_key == "lm_head": + logger.warning(f"Skip loading {hf_key} due to tie_word_embeddings") + return + if hf_key not in name_map and "t2d" in hf_key: + logger.warning( + f"Skip loading {hf_key} as it's not used in eagle-3 for now") + return + model_key = name_map.get(hf_key, hf_key) + + model_weight, model_sharding = get_param_and_sharding( + params, shardings, model_key) + + logger.debug( + "before transform | " + f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}" + ) + + if hf_key.endswith(".bias"): + for key in bias_reshape_keys: + if key in hf_key: + hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key]) + if head_dim_pad > 0: + hf_weight = jnp.pad(hf_weight, ((0, 0), (0, head_dim_pad))) + break + else: + for key in reshape_keys: + if key in hf_key: + hf_weight = jnp.reshape(hf_weight, reshape_keys[key]) + if head_dim_pad > 0: + if "o_proj" in key: + hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0), + (0, head_dim_pad))) + else: + hf_weight = jnp.pad(hf_weight, + ((0, 0), (0, head_dim_pad), + (0, 0))) + break + for key in transpose_keys: + if key in hf_key: + hf_weight = jnp.transpose(hf_weight, transpose_keys[key]) + break + + # Pad num-kv-heads + if hf_key.endswith(".bias"): + for key, value in bias_pad_keys.items(): + dim = value[0] + dim_size = value[1] + if key in hf_key and dim_size != 0: + hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim) + break + else: + for key, value in pad_keys.items(): + dim = value[0] + dim_size = value[1] + if key in hf_key and dim_size != 0: + hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim) + break + + logger.debug( + "after transform | " + f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}" + ) + + if head_dim_pad == 0: + assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}" + + # Update the model weight + spec = model_weight.sharding.spec if isinstance( + model_weight.sharding, NamedSharding) else model_weight.sharding + model_weight.value = shard(hf_weight, spec) + + +def _load_hf_weights_on_thread( + vllm_config: VllmConfig, + params: nnx.State, + metadata_map: "MetadataMap", + mesh: Mesh, + weights_file: str, + filter_regex: Optional[str] = None, + keep_original_dtype_keys_regex: Optional[list[str]] = None, + exclude_regex: Optional[list[str]] = None, +): + """Loads weights from a single weights file.""" try: shardings = nnx.get_named_sharding(params, mesh) except TypeError: @@ -297,7 +412,6 @@ def _load_hf_weights_on_thread(vllm_config, for hf_key, hf_weight in model_weights_single_file_generator( weights_file, framework="flax", filter_regex=filter_regex): - # Check if the key should be excluded if exclude_regex: should_exclude = False @@ -309,148 +423,89 @@ def _load_hf_weights_on_thread(vllm_config, break if should_exclude: continue - - # Check if the key should retain its original dtype - keep_original_dtype = False - if keep_original_dtype_keys_regex: - for pattern in keep_original_dtype_keys_regex: - if re.match(pattern, hf_key): - keep_original_dtype = True - break - - # Converting to config's dtype - if not keep_original_dtype and hf_weight.dtype != model_config.dtype: - logger.warning( - f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}" - ) - hf_weight = hf_weight.astype(model_config.dtype) - - if hf_key.endswith(".weight"): - hf_key = 
-
-        # Find the corresponding model key using the HF key
-        if "layers" in hf_key:
-            layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
-            layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
-            model_key = name_map[layer_key]
-            model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
-        elif "blocks" in hf_key:
-            layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
-            layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
-            model_key = name_map[layer_key]
-            model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
-        else:
-            if hf_key not in name_map and hf_key == "lm_head":
-                logger.warning(
-                    f"Skip loading {hf_key} due to tie_word_embeddings")
-                continue
-            if hf_key not in name_map and "t2d" in hf_key:
-                logger.warning(
-                    f"Skip loading {hf_key} as it's not used in eagle-3 for now"
-                )
-                continue
-            model_key = name_map.get(hf_key, hf_key)
-        model_weight, model_sharding = get_param_and_sharding(
-            params, shardings, model_key)
-
-        logger.debug(
-            "before transform | "
-            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+        _load_and_shard_weight(
+            vllm_config,
+            params,
+            shardings,
+            metadata_map,
+            mesh,
+            hf_key,
+            hf_weight,
+            keep_original_dtype_keys_regex,
         )
-        if hf_key.endswith(".bias"):
-            for key in bias_reshape_keys:
-                if key in hf_key:
-                    hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
-                    if head_dim_pad > 0:
-                        hf_weight = jnp.pad(hf_weight,
-                                            ((0, 0), (0, head_dim_pad)))
-                    break
-        else:
-            for key in reshape_keys:
-                if key in hf_key:
-                    hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
-                    if head_dim_pad > 0:
-                        if "o_proj" in key:
-                            hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
-                                                            (0, head_dim_pad)))
-                        else:
-                            hf_weight = jnp.pad(hf_weight,
-                                                ((0, 0), (0, head_dim_pad),
-                                                 (0, 0)))
-                    break
-            for key in transpose_keys:
-                if key in hf_key:
-                    hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
-                    break
-        # Pad num-kv-heads
-        if hf_key.endswith(".bias"):
-            for key, value in bias_pad_keys.items():
-                dim = value[0]
-                dim_size = value[1]
-                if key in hf_key and dim_size != 0:
-                    hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
-                    break
-        else:
-            for key, value in pad_keys.items():
-                dim = value[0]
-                dim_size = value[1]
-                if key in hf_key and dim_size != 0:
-                    hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
-                    break
+def load_hf_weights(
+    vllm_config: VllmConfig,
+    model: nnx.Module,
+    metadata_map: MetadataMap,
+    mesh: Mesh,
+    filter_regex: str | None = None,
+    is_draft_model: bool = False,
+    keep_original_dtype_keys_regex: list[str] | None = None,
+    exclude_regex: list[str] | None = None,
+):
+    """Load weights into a JAX model from either an iterator or files."""
+    params = nnx.state(model)
+    try:
+        shardings = nnx.get_named_sharding(params, mesh)
+    except TypeError:
+        shardings = params
+    weights_iterator = None
+    if hasattr(vllm_config.model_config, "model_weights_iterator"):
+        weights_iterator = vllm_config.model_config.model_weights_iterator
+    env = torchax.default_env()
+    # The weights_iterator is used in the RunAI Model Streamer integration.
+    if weights_iterator is not None:
+        for hf_key, hf_weight in weights_iterator:
+            if filter_regex and not re.match(filter_regex, hf_key):
+                continue
-        logger.debug(
-            "after transform | "
-            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
-        )
+            # The weights_iterator yields PyTorch tensors (torch.Tensor), so
+            # convert each one to a JAX array (jax.Array).
+            hf_weight_jax = env.t2j_copy(hf_weight)
-        if head_dim_pad == 0:
-            assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
-
-        # Update the model weight
-        spec = model_weight.sharding.spec if isinstance(
-            model_weight.sharding, NamedSharding) else model_weight.sharding
-        model_weight.value = shard(hf_weight, spec)
-
-
-def load_hf_weights(vllm_config,
-                    model: nnx.Module,
-                    metadata_map: MetadataMap,
-                    mesh: Mesh,
-                    filter_regex: str | None = None,
-                    is_draft_model: bool = False,
-                    keep_original_dtype_keys_regex: list[str] | None = None,
-                    exclude_regex: list[str] | None = None):
-    """Load weights from all model weights files to the model, run in multi threads."""
-    if is_draft_model:
-        model_path = vllm_config.speculative_config.draft_model_config.model
-    else:
-        model_path = vllm_config.model_config.model
-    weights_files = get_model_weights_files(
-        model_path, vllm_config.load_config.download_dir)
-    params = nnx.state(model)
-    max_workers = min(64, len(weights_files))
-    # NOTE(xiang): Disable multi-threading mode if running on multi-host.
-    # Because multi-threading would cause different JAX processes to load
-    # different weights at the same time.
-    if envs.TPU_MULTIHOST_BACKEND == "ray":
-        max_workers = 1
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = [
-            executor.submit(
-                _load_hf_weights_on_thread,
+            _load_and_shard_weight(
                 vllm_config,
                 params,
+                shardings,
                 metadata_map,
                 mesh,
-                weights_file,
-                filter_regex=filter_regex,
-                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
-                exclude_regex=exclude_regex) for weights_file in weights_files
-        ]
-        for future in futures:
-            future.result()
+                hf_key,
+                hf_weight_jax,
+                keep_original_dtype_keys_regex,
+            )
+    else:
+        # File-based path (multi-threaded)
+        if is_draft_model:
+            model_path = vllm_config.speculative_config.draft_model_config.model
+        else:
+            model_path = vllm_config.model_config.model
+        weights_files = get_model_weights_files(
+            model_path, vllm_config.load_config.download_dir)
+        max_workers = min(64, len(weights_files))
+        # NOTE(xiang): Disable multi-threading mode if running on multi-host,
+        # because multi-threading would cause different JAX processes to load
+        # different weights at the same time.
+        if envs.TPU_MULTIHOST_BACKEND == "ray":
+            max_workers = 1
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = [
+                executor.submit(_load_hf_weights_on_thread,
+                                vllm_config,
+                                params,
+                                metadata_map,
+                                mesh,
+                                weights_file,
+                                filter_regex=filter_regex,
+                                keep_original_dtype_keys_regex=
+                                keep_original_dtype_keys_regex,
+                                exclude_regex=exclude_regex)
+                for weights_file in weights_files
+            ]
+            for future in futures:
+                future.result()
+
     check_all_loaded(params)
     nnx.update(model, params)
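---
For anyone exercising the new load path outside CI, below is a minimal usage
sketch modeled on the e2e test above. The bucket path is a placeholder, not
part of this change; any GCS location holding the model's safetensors weights
should behave the same way, and anonymous access only works for public buckets.

# Hypothetical usage sketch (mirrors tests/e2e/test_runai_model_streamer_loader.py).
import os

from vllm import LLM, SamplingParams

# Anonymous GCS access, as in the e2e test; valid only for public buckets.
os.environ["RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS"] = "true"

llm = LLM(
    model="gs://your-bucket/your-model-dir",  # placeholder, not a real bucket
    load_format="runai_streamer",  # stream weights instead of downloading first
    max_model_len=128,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0, max_tokens=10))
print(outputs[0].outputs[0].text)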
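The only conversion step in the new iterator path is the torch-to-JAX hand-off
inside load_hf_weights. Below is a standalone sketch of that step, using the
same torchax calls the patch relies on; the tensor shape and dtype here are
made up for illustration.

# The streamer's iterator yields torch.Tensor objects; torchax's default
# environment copies each one into a jax.Array before it is sharded onto
# the mesh.
import torch
import torchax

env = torchax.default_env()
weight = torch.ones(4, 4, dtype=torch.bfloat16)  # stand-in for a streamed weight
weight_jax = env.t2j_copy(weight)  # jax.Array with matching shape and dtype
print(weight_jax.shape, weight_jax.dtype)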