26 changes: 26 additions & 0 deletions .buildkite/features/runai_model_streamer_loader.yml
@@ -0,0 +1,26 @@
# runai_model_streamer_loader
# The RunAI Model Streamer is a high-performance model loader that serves as an
# alternative to the default Hugging Face loader. Instead of downloading a model
# to local disk, it streams the weights from object storage (like GCS) into
# GPU memory. This streaming process is significantly faster than the traditional
# disk-based loading method.
steps:
- label: "Correctness tests for runai_model_streamer_loader"
key: "runai_model_streamer_loader_CorrectnessTest"
soft_fail: true
agents:
queue: tpu_v6e_queue
commands:
- .buildkite/scripts/run_in_docker.sh python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_runai_model_streamer_loader.py::test_correctness
- label: "Record correctness test result for runai_model_streamer_loader"
key: "record_runai_model_streamer_loader_CorrectnessTest"
depends_on: "runai_model_streamer_loader_CorrectnessTest"
env:
CI_TARGET: "runai_model_streamer_loader"
CI_STAGE: "CorrectnessTest"
CI_CATEGORY: "feature support matrix"
agents:
queue: cpu
commands:
- |
.buildkite/scripts/record_step_result.sh runai_model_streamer_loader_CorrectnessTest
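To make the CI job above concrete: what the correctness test exercises is vLLM's "runai_streamer" load format, which streams weights from GCS instead of downloading them to local disk first. The following minimal sketch is not part of this diff; the bucket path, environment variables, and parameters mirror the e2e test added below, and the exact values are illustrative.

# Minimal sketch, assuming vLLM and runai-model-streamer[gcs] are installed
# and the bucket allows anonymous reads (values mirror the e2e test below).
import os

from vllm import LLM, SamplingParams

# Anonymous GCS access, as configured in the e2e test.
os.environ["GOOGLE_CLOUD_PROJECT"] = "fake-project"
os.environ["RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS"] = "true"

llm = LLM(model="gs://vertex-model-garden-public-us/llama3/llama3-8b-hf",
          load_format="runai_streamer",  # stream weights from object storage
          max_model_len=128)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0, max_tokens=10))
print(outputs[0].outputs[0].text)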
1 change: 1 addition & 0 deletions requirements.txt
@@ -15,3 +15,4 @@ torchvision==0.24.0
pathwaysutils
parameterized
numba==0.62.1
runai-model-streamer[s3,gcs]==0.15.0
90 changes: 90 additions & 0 deletions tests/e2e/test_runai_model_streamer_loader.py
@@ -0,0 +1,90 @@
# This file contains end-to-end tests for the RunAI Model Streamer loader.
#
# The RunAI Model Streamer is a high-performance model loader that serves as an
# alternative to the default Hugging Face loader. Instead of downloading a model
# to local disk, it streams the weights from object storage (like GCS) into
# GPU memory. This streaming process is significantly faster than the
# traditional disk-based loading method.

# The tests in this file verify that loading model weights using the
# streamer produces the same results as loading the same model using the
# standard Hugging Face loader. This ensures the correctness of the streamer
# integration.

# The tests are performed by:
# 1. Loading a model from Google Cloud Storage using the `runai_streamer` format.
# 2. Generating output with this model.
# 3. Loading the same model from Hugging Face using the default loader.
# 4. Generating output with this second model.
# 5. Asserting that the outputs from both models are identical.

from __future__ import annotations

import time

import pytest
from vllm import LLM, SamplingParams


@pytest.fixture
def sampling_config():
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=True)


# TODO(amacaskill): Replace with a GKE-owned GCS bucket.
@pytest.fixture
def gcs_model_name():
return "gs://vertex-model-garden-public-us/llama3/llama3-8b-hf"


@pytest.fixture
def hf_model_name():
return "meta-llama/Meta-Llama-3-8B"


@pytest.fixture
def prompt():
return "Hello, my name is"


def test_correctness(
sampling_config: SamplingParams,
gcs_model_name: str,
hf_model_name: str,
prompt: str,
monkeypatch: pytest.MonkeyPatch,
):
    '''
    Compare the outputs of a model loaded from GCS via runai_model_streamer
    and a model loaded from Hugging Face. The outputs should be identical.
    The test runs with tensor_parallel_size=1 (the default): the model is
    16GB and a v6e chip has 32GB of HBM, so it fits on a single device.
    '''
# Set ENV variables so that runai_model_streamer uses anonymous GCS access.
monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "fake-project")
monkeypatch.setenv("RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS", "true")
monkeypatch.setenv("CLOUD_STORAGE_EMULATOR_ENDPOINT",
"https://storage.googleapis.com")
gcs_llm = LLM(model=gcs_model_name,
load_format="runai_streamer",
max_model_len=128,
max_num_seqs=16,
max_num_batched_tokens=256)
gcs_outputs = gcs_llm.generate([prompt], sampling_config)
gcs_output_text = gcs_outputs[0].outputs[0].text
del gcs_llm
time.sleep(10) # Wait for TPUs to be released

# Test with Hugging Face model
hf_llm = LLM(model=hf_model_name,
max_model_len=128,
max_num_seqs=16,
max_num_batched_tokens=256)
hf_outputs = hf_llm.generate([prompt], sampling_config)
hf_output_text = hf_outputs[0].outputs[0].text
del hf_llm
time.sleep(10) # Wait for TPUs to be released

assert gcs_output_text == hf_output_text, (
f"Outputs do not match! "
f"GCS output: {gcs_output_text}, HF output: {hf_output_text}")
1 change: 1 addition & 0 deletions tests/models/common/test_model_loader.py
@@ -55,6 +55,7 @@ def vllm_config() -> MagicMock:
mock_config.model_config.dtype = jnp.bfloat16
mock_config.load_config = MagicMock()
mock_config.load_config.download_dir = None
mock_config.load_config.load_format = "auto"
mock_config.additional_config = {}
mock_config.cache_config = MagicMock(cache_dtype="auto")
return mock_config
21 changes: 20 additions & 1 deletion tpu_inference/models/common/model_loader.py
@@ -8,6 +8,9 @@
from torchax.ops.mappings import j2t_dtype
from transformers import PretrainedConfig
from vllm.config import VllmConfig
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.runai_streamer_loader import \
RunaiModelStreamerLoader
from vllm.utils.func_utils import supports_kw

from tpu_inference import envs
@@ -177,7 +180,23 @@ def create_sharded_model():
# the model creation again, otherwise the model forward will have
# non-trivial overhead in PjitFunction.
with mesh:
model.load_weights(rng)
loader = get_model_loader(vllm_config.load_config)
if isinstance(loader, RunaiModelStreamerLoader):
model_weights = vllm_config.model_config.model
if hasattr(vllm_config.model_config, "model_weights"):
model_weights = vllm_config.model_config.model_weights
weights_iterator = loader._get_weights_iterator(
model_weights, vllm_config.model_config.revision)
# Set the weights iterator at runtime to avoid changing every model's
# load_weights signature. It also avoids a TypeError when the
# RunaiModelStreamerLoader is used with a flax_nnx model whose
# load_weights does not accept a weights_iterator keyword argument.
vllm_config.model_config.model_weights_iterator = weights_iterator
model.load_weights(rng)
del vllm_config.model_config.model_weights_iterator
else:
model.load_weights(rng)
jit_model = create_jit_model(
model,
use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
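For reference, a hedged sketch of how a flax_nnx model's load_weights might consume the injected iterator. Only the model_weights_iterator attribute set above comes from this diff; the class name and helper methods below are hypothetical, for illustration only.

# Hypothetical sketch: only model_config.model_weights_iterator is real
# (set at runtime by the loader change above); the class and helpers are
# illustrative stand-ins.
class ExampleFlaxModel:
    def __init__(self, vllm_config):
        self.vllm_config = vllm_config

    def load_weights(self, rng):
        weights_iterator = getattr(self.vllm_config.model_config,
                                   "model_weights_iterator", None)
        if weights_iterator is not None:
            # Streamed path: consume (name, tensor) pairs produced by the
            # RunaiModelStreamerLoader without staging files on local disk.
            for name, tensor in weights_iterator:
                self._assign_weight(name, tensor)  # hypothetical helper
        else:
            # Default path: fall back to the standard Hugging Face load.
            self._load_from_hf_checkpoint(rng)  # hypothetical helper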