From d3b1d45f0d93eb08178ce226dcdcc3722230926e Mon Sep 17 00:00:00 2001
From: Aurick Qiao
Date: Wed, 8 May 2024 15:07:46 -0400
Subject: [PATCH 01/20] implement sharded state loader

---
 examples/save_sharded_state.py             | 67 ++++++++++++++++
 vllm/config.py                             |  1 +
 vllm/executor/distributed_gpu_executor.py  | 10 +++
 vllm/model_executor/model_loader/loader.py | 91 ++++++++++++++++++++++
 vllm/worker/model_runner.py                | 14 ++++
 vllm/worker/worker.py                      | 12 +++
 6 files changed, 195 insertions(+)
 create mode 100644 examples/save_sharded_state.py

diff --git a/examples/save_sharded_state.py b/examples/save_sharded_state.py
new file mode 100644
index 00000000000..a3a19d8c622
--- /dev/null
+++ b/examples/save_sharded_state.py
@@ -0,0 +1,67 @@
+import argparse
+import dataclasses
+import os
+import shutil
+from pathlib import Path
+
+from vllm import LLM, EngineArgs
+
+"""
+Saves each worker's model state dict directly to a checkpoint, which enables a
+fast load path for large tensor-parallel models where each worker only needs to
+read its own shard rather than the entire checkpoint.
+
+Example usage:
+
+python save_sharded_state.py \
+    --model /path/to/load \
+    --quantization deepspeedfp \
+    --tensor-parallel-size 8 \
+    --output /path/to/save
+
+Then, the model can be loaded with
+
+llm = LLM(
+    model="/path/to/save",
+    load_format="sharded_state",
+    quantization="deepspeedfp",
+    tensor_parallel_size=8,
+)
+"""
+
+parser = argparse.ArgumentParser()
+EngineArgs.add_cli_args(parser)
+parser.add_argument("--output",
+                    "-o",
+                    required=True,
+                    type=str,
+                    help="path to output checkpoint")
+parser.add_argument("--pattern",
+                    type=str,
+                    help="string pattern of saved filenames")
+
+
+def main(args):
+    engine_args = EngineArgs.from_cli_args(args)
+    model_path = engine_args.model
+    if not Path(model_path).is_dir():
+        raise ValueError("model path must be a local directory")
+    # Create LLM instance from arguments
+    llm = LLM(**dataclasses.asdict(engine_args))
+    # Prepare output directory
+    Path(args.output).mkdir(exist_ok=True)
+    # Dump worker states to output directory
+    model_executor = llm.llm_engine.model_executor
+    model_executor.save_sharded_state(path=args.output,
+                                      pattern=args.pattern,
+                                      max_size=5 * 1024**3)
+    # Copy metadata files to output directory
+    for file in os.listdir(model_path):
+        if not any(
+                file.endswith(ext) for ext in (".bin", ".pt", ".safetensors")):
+            shutil.copy(f"{model_path}/{file}", args.output)
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    main(args)
diff --git a/vllm/config.py b/vllm/config.py
index 5c3a8615eef..94d36705052 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -455,6 +455,7 @@ class LoadFormat(str, enum.Enum):
     NPCACHE = "npcache"
     DUMMY = "dummy"
     TENSORIZER = "tensorizer"
+    SHARDED_STATE = "sharded_state"
 
 
 @dataclass
diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py
index 4c922ef63ee..86f783cf6af 100644
--- a/vllm/executor/distributed_gpu_executor.py
+++ b/vllm/executor/distributed_gpu_executor.py
@@ -77,6 +77,16 @@ def remove_lora(self, lora_id: int) -> bool:
     def list_loras(self) -> Set[int]:
         return self._run_workers("list_loras")
 
+    def save_sharded_state(self,
+                           path: str,
+                           pattern: Optional[str] = None,
+                           max_size: Optional[int] = None,
+                           ) -> None:
+        self._run_workers("save_sharded_state",
+                          path=path,
+                          pattern=pattern,
+                          max_size=max_size)
+
     @abstractmethod
     def _run_workers(
         self,
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index bafa2de62e5..614204ce2bc 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -347,6 +347,94 @@ def load_model(self, *, model_config: ModelConfig,
                                              vision_language_config)
 
 
+class ShardedStateLoader(BaseModelLoader):
+    """
+    Model loader that directly loads each worker's model state dict, which
+    enables a fast load path for large tensor-parallel models where each worker
+    only needs to read its own shard rather than the entire checkpoint. See
+    `examples/save_sharded_state.py` for creating a sharded checkpoint.
+    """
+
+    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        extra_config = ({} if load_config.model_loader_extra_config is None
+                        else load_config.model_loader_extra_config.copy())
+        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
+        if extra_config:
+            raise ValueError(f"Unexpected extra config keys for load format "
+                             f"{load_config.load_format}: "
+                             f"{load_config.model_loader_extra_config.keys()}")
+
+    def load_model(self, *, model_config: ModelConfig,
+                   device_config: DeviceConfig,
+                   lora_config: Optional[LoRAConfig],
+                   vision_language_config: Optional[VisionLanguageConfig],
+                   parallel_config: ParallelConfig,
+                   scheduler_config: SchedulerConfig) -> nn.Module:
+        from safetensors.torch import load_file
+
+        from vllm.distributed import get_tensor_model_parallel_rank
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config,
+                                          lora_config, vision_language_config)
+            rank = get_tensor_model_parallel_rank()
+            pattern = os.path.join(
+                model_config.model,
+                self.pattern.format(rank=rank, part="*"),
+            )
+            filepaths = glob.glob(pattern)
+            if not filepaths:
+                # TODO: support un-sharded checkpoints too
+                raise ValueError(
+                    f"Could not find checkpoint files '{pattern}', only "
+                    f"pre-sharded checkpoints are currently supported!"
+                )
+            state_dict = dict(model.state_dict())
+            for path in filepaths:
+                for key, val in load_file(path).items():
+                    state_dict[key].copy_(val)
+                    state_dict.pop(key)
+            assert not state_dict
+        return model.eval()
+
+    @staticmethod
+    def save_model(
+        model: torch.nn.Module,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        from safetensors.torch import save_file
+
+        from vllm.distributed import get_tensor_model_parallel_rank
+        if pattern is None:
+            pattern = ShardedStateLoader.DEFAULT_PATTERN
+        rank = get_tensor_model_parallel_rank()
+        part = 0
+        total_size = 0
+        state_dict: Dict[str, torch.Tensor] = {}
+        for name, tensor in model.state_dict().items():
+            param_size = tensor.nelement() * tensor.element_size()
+            if max_size is not None and total_size + param_size > max_size:
+                save_file(
+                    state_dict,
+                    os.path.join(path, pattern.format(rank=rank, part=part)),
+                )
+                part += 1
+                total_size = 0
+                state_dict = {}
+            state_dict[name] = tensor
+            total_size += param_size
+        if len(state_dict) > 0:
+            save_file(
+                state_dict,
+                os.path.join(path, pattern.format(rank=rank, part=part)),
+            )
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
 
@@ -359,4 +447,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.TENSORIZER:
         return TensorizerLoader(load_config)
 
+    if load_config.load_format == LoadFormat.SHARDED_STATE:
+        return ShardedStateLoader(load_config)
+
     return DefaultModelLoader(load_config)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 46c6730645c..5d17eb828d1 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -211,6 +211,20 @@ def load_model(self) -> None:
                                "but the KV cache data type is not FP8. "
                                "KV cache scaling factors will not be used.")
 
+    def save_sharded_state(
+        self,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        from vllm.model_executor.model_loader.loader import ShardedStateLoader
+        ShardedStateLoader.save_model(
+            self.model,
+            path,
+            pattern=pattern,
+            max_size=max_size,
+        )
+
     def set_block_size(self, block_size: int) -> None:
         self.block_size = block_size
 
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 538332ad003..2d3ddad8abc 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -117,6 +117,18 @@ def init_device(self) -> None:
     def load_model(self):
         self.model_runner.load_model()
 
+    def save_sharded_state(
+        self,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        self.model_runner.save_sharded_state(
+            path,
+            pattern=pattern,
+            max_size=max_size,
+        )
+
     @torch.inference_mode()
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Profiles the peak memory usage of the model to determine how many

From 8acd8b4e778af453276684e7733b99e1e820db7d Mon Sep 17 00:00:00 2001
From: Aurick Qiao
Date: Thu, 9 May 2024 18:34:02 -0400
Subject: [PATCH 02/20] add test

---
 tests/test_sharded_state_loader.py         | 55 ++++++++++++++++++++
 vllm/model_executor/model_loader/loader.py | 13 ++++-
 2 files changed, 67 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_sharded_state_loader.py

diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py
new file mode 100644
index 00000000000..2ea94646d79
--- /dev/null
+++ b/tests/test_sharded_state_loader.py
@@ -0,0 +1,55 @@
+import os
+import shutil
+from tempfile import TemporaryDirectory
+
+from vllm import LLM, SamplingParams
+from huggingface_hub import snapshot_download
+
+
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) + + +def test_sharded_state_loader(): + weights_patterns = ("*.bin", "*.pt", "*.safetensors") + + cache_dir = TemporaryDirectory() + input_dir = snapshot_download("facebook/opt-125m", + cache_dir=cache_dir.name) + + llm_before = LLM( + model=input_dir, + worker_use_ray=True, + gpu_memory_utilization=0.1, + ) + gen_before = llm_before.generate(prompts, sampling_params) + out_before = [gen.outputs[0].__dict__ for gen in gen_before] + + # Dump worker states to output directory + model_executor = llm_before.llm_engine.model_executor + output_dir = TemporaryDirectory() + model_executor.save_sharded_state(path=output_dir.name) + # Copy metadata files to output directory + for file in os.listdir(input_dir): + if not any(file.endswith(ext) for ext in weights_patterns): + shutil.copy(f"{input_dir}/{file}", output_dir.name) + + cache_dir.cleanup() + + llm_after = LLM( + model=output_dir.name, + worker_use_ray=True, + gpu_memory_utilization=0.1, + load_format="sharded_state", + ) + gen_after = llm_after.generate(prompts, sampling_params) + out_after = [gen.outputs[0].__dict__ for gen in gen_after] + + assert out_before == out_after \ No newline at end of file diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 614204ce2bc..a801cf2e725 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -393,11 +393,17 @@ def load_model(self, *, model_config: ModelConfig, f"pre-sharded checkpoints are currently supported!" ) state_dict = dict(model.state_dict()) + data_ptrs = {} for path in filepaths: for key, val in load_file(path).items(): + data_ptrs[state_dict[key].data_ptr()] = key state_dict[key].copy_(val) state_dict.pop(key) - assert not state_dict + for key, val in state_dict.items(): + if val.data_ptr() in data_ptrs: + logger.warning(f"Skipping loading shared tensor '{key}'") + else: + raise ValueError(f"Missing key '{key}' in loaded state!") return model.eval() @staticmethod @@ -416,7 +422,12 @@ def save_model( part = 0 total_size = 0 state_dict: Dict[str, torch.Tensor] = {} + data_ptrs = {} for name, tensor in model.state_dict().items(): + if tensor.data_ptr() in data_ptrs: + logger.warning(f"Skipping saving shared tensor '{name}'") + continue + data_ptrs[tensor.data_ptr()] = name param_size = tensor.nelement() * tensor.element_size() if max_size is not None and total_size + param_size > max_size: save_file( From 217890bddc18ec39065a3efcec779ae4505e86bb Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 9 May 2024 18:41:29 -0400 Subject: [PATCH 03/20] small --- tests/test_sharded_state_loader.py | 2 +- vllm/model_executor/model_loader/loader.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index 2ea94646d79..b30a2afdaf6 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -52,4 +52,4 @@ def test_sharded_state_loader(): gen_after = llm_after.generate(prompts, sampling_params) out_after = [gen.outputs[0].__dict__ for gen in gen_after] - assert out_before == out_after \ No newline at end of file + assert out_before == out_after diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index a801cf2e725..c4ae1e5f098 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -401,7 +401,7 @@ def 
load_model(self, *, model_config: ModelConfig, state_dict.pop(key) for key, val in state_dict.items(): if val.data_ptr() in data_ptrs: - logger.warning(f"Skipping loading shared tensor '{key}'") + logger.warning("Skipping loading shared tensor '%s'", key) else: raise ValueError(f"Missing key '{key}' in loaded state!") return model.eval() @@ -425,7 +425,7 @@ def save_model( data_ptrs = {} for name, tensor in model.state_dict().items(): if tensor.data_ptr() in data_ptrs: - logger.warning(f"Skipping saving shared tensor '{name}'") + logger.warning("Skipping saving shared tensor '%s'", name) continue data_ptrs[tensor.data_ptr()] = name param_size = tensor.nelement() * tensor.element_size() From a1b9e2a5b1d28808de6e90b2be48c3cd0aea8eb2 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 9 May 2024 18:42:30 -0400 Subject: [PATCH 04/20] small --- tests/test_sharded_state_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index b30a2afdaf6..dca20315739 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -2,9 +2,9 @@ import shutil from tempfile import TemporaryDirectory -from vllm import LLM, SamplingParams from huggingface_hub import snapshot_download +from vllm import LLM, SamplingParams prompts = [ "Hello, my name is", From a4b11162f892f2dc7c90a2984e2d376776b0f510 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 9 May 2024 22:44:55 -0400 Subject: [PATCH 05/20] review --- examples/save_sharded_state.py | 13 ++-- tests/test_sharded_state_loader.py | 89 +++++++++++++--------- vllm/model_executor/model_loader/loader.py | 69 ++++++++++------- 3 files changed, 106 insertions(+), 65 deletions(-) diff --git a/examples/save_sharded_state.py b/examples/save_sharded_state.py index a3a19d8c622..eafddc61611 100644 --- a/examples/save_sharded_state.py +++ b/examples/save_sharded_state.py @@ -36,9 +36,13 @@ required=True, type=str, help="path to output checkpoint") -parser.add_argument("--pattern", +parser.add_argument("--file-pattern", type=str, help="string pattern of saved filenames") +parser.add_argument("--max-file-size", + type=str, + default=5 * 1024 ** 3, + help="max size (in bytes) of each safetensors file") def main(args): @@ -53,12 +57,11 @@ def main(args): # Dump worker states to output directory model_executor = llm.llm_engine.model_executor model_executor.save_sharded_state(path=args.output, - pattern=args.pattern, - max_size=5 * 1024**3) + pattern=args.file_pattern, + max_size=args.max_file_size) # Copy metadata files to output directory for file in os.listdir(model_path): - if not any( - file.endswith(ext) for ext in (".bin", ".pt", ".safetensors")): + if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"): shutil.copy(f"{model_path}/{file}", args.output) diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index dca20315739..b2329eb65b9 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -2,9 +2,11 @@ import shutil from tempfile import TemporaryDirectory +import torch from huggingface_hub import snapshot_download from vllm import LLM, SamplingParams +from vllm.model_executor.model_loader.loader import ShardedStateLoader prompts = [ "Hello, my name is", @@ -14,42 +16,61 @@ ] # Create a sampling params object. 
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) +sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + seed=0, + max_tokens=256, + ignore_eos=True, +) + + +def test_filter_subtensors(): + state_dict = { + "a": torch.empty(2), + "b": torch.empty((2, 4)), + "c": torch.empty((2, 4, 8)), + } + state_dict.update({ + "x": state_dict["b"], + "y": state_dict["c"][1, 2, :], + "z": state_dict["c"][1, :, 4], + }) + filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict) + assert tuple(filtered_state_dict.keys()) == ("a", "b", "c") + for key, tensor in filtered_state_dict.items(): + assert tensor.equal(state_dict[key]) def test_sharded_state_loader(): weights_patterns = ("*.bin", "*.pt", "*.safetensors") - cache_dir = TemporaryDirectory() - input_dir = snapshot_download("facebook/opt-125m", - cache_dir=cache_dir.name) - - llm_before = LLM( - model=input_dir, - worker_use_ray=True, - gpu_memory_utilization=0.1, - ) - gen_before = llm_before.generate(prompts, sampling_params) - out_before = [gen.outputs[0].__dict__ for gen in gen_before] - - # Dump worker states to output directory - model_executor = llm_before.llm_engine.model_executor - output_dir = TemporaryDirectory() - model_executor.save_sharded_state(path=output_dir.name) - # Copy metadata files to output directory - for file in os.listdir(input_dir): - if not any(file.endswith(ext) for ext in weights_patterns): - shutil.copy(f"{input_dir}/{file}", output_dir.name) - - cache_dir.cleanup() - - llm_after = LLM( - model=output_dir.name, - worker_use_ray=True, - gpu_memory_utilization=0.1, - load_format="sharded_state", - ) - gen_after = llm_after.generate(prompts, sampling_params) - out_after = [gen.outputs[0].__dict__ for gen in gen_after] - - assert out_before == out_after + with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir: + input_dir = snapshot_download("facebook/opt-125m", cache_dir=cache_dir) + + llm_before = LLM( + model=input_dir, + worker_use_ray=True, + gpu_memory_utilization=0.1, + ) + gen_before = llm_before.generate(prompts, sampling_params) + out_before = [gen.outputs[0].__dict__ for gen in gen_before] + + # Dump worker states to output directory + model_executor = llm_before.llm_engine.model_executor + model_executor.save_sharded_state(path=output_dir) + # Copy metadata files to output directory + for file in os.listdir(input_dir): + if not any(file.endswith(ext) for ext in weights_patterns): + shutil.copy(f"{input_dir}/{file}", output_dir) + + llm_after = LLM( + model=output_dir, + worker_use_ray=True, + gpu_memory_utilization=0.1, + load_format="sharded_state", + ) + gen_after = llm_after.generate(prompts, sampling_params) + out_after = [gen.outputs[0].__dict__ for gen in gen_after] + + assert out_before == out_after diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index c4ae1e5f098..7f80cae2391 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -367,6 +367,31 @@ def __init__(self, load_config: LoadConfig): f"{load_config.load_format}: " f"{load_config.model_loader_extra_config.keys()}") + @staticmethod + def _filter_subtensors(tensors: Dict[str, torch.Tensor]): + """ + Filter out all tensors that share the same memory or a subset of the + memory of another tensor. 
+ """ + from safetensors.torch import storage_ptr, storage_size + result = {} + for key1, tensor1 in tensors.items(): + a1 = storage_ptr(tensor1) # tensor1 start + b1 = a1 + storage_size(tensor1) # tensor1 end + for key2, tensor2 in tensors.items(): + a2 = storage_ptr(tensor2) # tensor2 start + b2 = a2 + storage_size(tensor2) #tensor2 end + if (a1, b1) == (a2, b2): + # Same memory, take only the first key (lexicographically). + if key2 < key1: + break + elif a1 <= a2 and b2 <= b1: + # tensor1 is a subtensor of tensor2. + break + else: + result[key1] = tensor1 + return result + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], @@ -392,18 +417,14 @@ def load_model(self, *, model_config: ModelConfig, f"Could not find checkpoint files '{pattern}', only " f"pre-sharded checkpoints are currently supported!" ) - state_dict = dict(model.state_dict()) - data_ptrs = {} + state_dict = self._filter_subtensors(model.state_dict()) for path in filepaths: - for key, val in load_file(path).items(): - data_ptrs[state_dict[key].data_ptr()] = key - state_dict[key].copy_(val) + for key, tensor in load_file(path).items(): + state_dict[key].copy_(tensor) state_dict.pop(key) - for key, val in state_dict.items(): - if val.data_ptr() in data_ptrs: - logger.warning("Skipping loading shared tensor '%s'", key) - else: - raise ValueError(f"Missing key '{key}' in loaded state!") + if state_dict: + raise ValueError( + f"Missing keys {tuple(state_dict)} in loaded state!") return model.eval() @staticmethod @@ -419,30 +440,26 @@ def save_model( if pattern is None: pattern = ShardedStateLoader.DEFAULT_PATTERN rank = get_tensor_model_parallel_rank() - part = 0 + part_idx = 0 total_size = 0 - state_dict: Dict[str, torch.Tensor] = {} - data_ptrs = {} - for name, tensor in model.state_dict().items(): - if tensor.data_ptr() in data_ptrs: - logger.warning("Skipping saving shared tensor '%s'", name) - continue - data_ptrs[tensor.data_ptr()] = name + state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) + state_dict_part: Dict[str, torch.Tensor] = {} + for key, tensor in state_dict.items(): param_size = tensor.nelement() * tensor.element_size() if max_size is not None and total_size + param_size > max_size: save_file( - state_dict, - os.path.join(path, pattern.format(rank=rank, part=part)), + state_dict_part, + os.path.join(path, pattern.format(rank=rank, part=part_idx)), ) - part += 1 + part_idx += 1 total_size = 0 - state_dict = {} - state_dict[name] = tensor + state_dict_part = {} + state_dict_part[key] = tensor total_size += param_size - if len(state_dict) > 0: + if len(state_dict_part) > 0: save_file( - state_dict, - os.path.join(path, pattern.format(rank=rank, part=part)), + state_dict_part, + os.path.join(path, pattern.format(rank=rank, part=part_idx)), ) From 49bebb383c6b161e2a4eeec0d038e41c8f0e264f Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 9 May 2024 22:48:07 -0400 Subject: [PATCH 06/20] ruff --- vllm/model_executor/model_loader/loader.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 7f80cae2391..8489ce67ae5 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -447,9 +447,10 @@ def save_model( for key, tensor in state_dict.items(): param_size = tensor.nelement() * tensor.element_size() if max_size is not None and total_size + param_size > max_size: + filename 
= pattern.format(rank=rank, part=part_idx) save_file( state_dict_part, - os.path.join(path, pattern.format(rank=rank, part=part_idx)), + os.path.join(path, filename), ) part_idx += 1 total_size = 0 @@ -457,9 +458,10 @@ def save_model( state_dict_part[key] = tensor total_size += param_size if len(state_dict_part) > 0: + filename = pattern.format(rank=rank, part=part_idx) save_file( state_dict_part, - os.path.join(path, pattern.format(rank=rank, part=part_idx)), + os.path.join(path, filename), ) From 4567b701077d220ca30b342f7ec16d41eac70f39 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 10 May 2024 15:46:49 -0400 Subject: [PATCH 07/20] review --- vllm/model_executor/model_loader/loader.py | 34 +++++++++------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 8489ce67ae5..a6339a0185f 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -374,23 +374,16 @@ def _filter_subtensors(tensors: Dict[str, torch.Tensor]): memory of another tensor. """ from safetensors.torch import storage_ptr, storage_size - result = {} - for key1, tensor1 in tensors.items(): - a1 = storage_ptr(tensor1) # tensor1 start - b1 = a1 + storage_size(tensor1) # tensor1 end - for key2, tensor2 in tensors.items(): - a2 = storage_ptr(tensor2) # tensor2 start - b2 = a2 + storage_size(tensor2) #tensor2 end - if (a1, b1) == (a2, b2): - # Same memory, take only the first key (lexicographically). - if key2 < key1: - break - elif a1 <= a2 and b2 <= b1: - # tensor1 is a subtensor of tensor2. + tensors = tensors.copy() + starts = sorted([(storage_ptr(t), k) for k, t in tensors.items()]) + stops = sorted([(start + storage_size(tensors[key]), start_idx) + for start_idx, (start, key) in enumerate(starts)]) + for stop, start_idx in stops: + for i in range(start_idx + 1, len(starts)): + if starts[i][0] >= stop: break - else: - result[key1] = tensor1 - return result + tensors.pop(starts[i][1], None) + return tensors def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, @@ -398,7 +391,7 @@ def load_model(self, *, model_config: ModelConfig, vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module: - from safetensors.torch import load_file + from safetensors import safe_open from vllm.distributed import get_tensor_model_parallel_rank with set_default_torch_dtype(model_config.dtype): @@ -419,9 +412,10 @@ def load_model(self, *, model_config: ModelConfig, ) state_dict = self._filter_subtensors(model.state_dict()) for path in filepaths: - for key, tensor in load_file(path).items(): - state_dict[key].copy_(tensor) - state_dict.pop(key) + with safe_open(path, framework="pt") as f: + for key in f.keys(): + state_dict[key].copy_(f.get_tensor(key)) + state_dict.pop(key) if state_dict: raise ValueError( f"Missing keys {tuple(state_dict)} in loaded state!") From 72b7148e0033378ef48197e77782c1e1c2ff6805 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 10 May 2024 15:52:57 -0400 Subject: [PATCH 08/20] ruff --- vllm/model_executor/model_loader/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index a6339a0185f..f2069002b86 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -413,7 +413,7 @@ def load_model(self, 
*, model_config: ModelConfig, state_dict = self._filter_subtensors(model.state_dict()) for path in filepaths: with safe_open(path, framework="pt") as f: - for key in f.keys(): + for key in f: # noqa: SIM118 state_dict[key].copy_(f.get_tensor(key)) state_dict.pop(key) if state_dict: From 2f5c2e04af4d178a5079074014d74a9b38841bed Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 10 May 2024 19:32:29 -0400 Subject: [PATCH 09/20] update --- vllm/model_executor/model_loader/loader.py | 45 +++++++++++++++------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index f2069002b86..b5d06fce2fe 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1,4 +1,5 @@ # ruff: noqa: SIM117 +import collections import copy import glob import os @@ -368,22 +369,40 @@ def __init__(self, load_config: LoadConfig): f"{load_config.model_loader_extra_config.keys()}") @staticmethod - def _filter_subtensors(tensors: Dict[str, torch.Tensor]): + def _filter_subtensors( + tensors: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: """ Filter out all tensors that share the same memory or a subset of the memory of another tensor. """ - from safetensors.torch import storage_ptr, storage_size - tensors = tensors.copy() - starts = sorted([(storage_ptr(t), k) for k, t in tensors.items()]) - stops = sorted([(start + storage_size(tensors[key]), start_idx) - for start_idx, (start, key) in enumerate(starts)]) - for stop, start_idx in stops: - for i in range(start_idx + 1, len(starts)): - if starts[i][0] >= stop: - break - tensors.pop(starts[i][1], None) - return tensors + same_storage_groups = collections.defaultdict(list) + for key, tensor in tensors.items(): + if tensor.numel(): + ptr = tensor.untyped_storage().data_ptr() + same_storage_groups[tensor.device, ptr].append((key, tensor)) + + def get_end_ptr(tensor: torch.Tensor) -> int: + return tensor.view(-1)[-1].data_ptr() + tensor.element_size() + + result = {} + for group in same_storage_groups.values(): + for k, t in group: + a, b = t.data_ptr(), get_end_ptr(t) + for k2, t2 in group: + if not t2.is_contiguous(): + continue + a2, b2 = t2.data_ptr(), get_end_ptr(t2) + if a < a2 or b2 < b: + continue + if a2 < a or b < b2 or not t.is_contiguous(): + break # t2 covers strictly more memory than t. + if k2 < k: + # Same tensors, keep the one with the smaller key. 
+ break + else: + result[k] = t + return result def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, @@ -413,7 +432,7 @@ def load_model(self, *, model_config: ModelConfig, state_dict = self._filter_subtensors(model.state_dict()) for path in filepaths: with safe_open(path, framework="pt") as f: - for key in f: # noqa: SIM118 + for key in f.keys(): # noqa: SIM118 state_dict[key].copy_(f.get_tensor(key)) state_dict.pop(key) if state_dict: From 0fb888563d5c1e5b45d03647fc2ad956618190f2 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 13 May 2024 15:28:13 -0400 Subject: [PATCH 10/20] narrow model shape --- vllm/model_executor/model_loader/loader.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index b5d06fce2fe..368748af3b2 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -433,7 +433,12 @@ def load_model(self, *, model_config: ModelConfig, for path in filepaths: with safe_open(path, framework="pt") as f: for key in f.keys(): # noqa: SIM118 - state_dict[key].copy_(f.get_tensor(key)) + tensor = f.get_tensor(key) + for dim, size in enumerate(tensor.shape): + state_dict[key].data = ( + state_dict[key].data.narrow(dim, 0, size) + ) + state_dict[key].data.copy_(tensor) state_dict.pop(key) if state_dict: raise ValueError( From 006ad5998649a935738e21da5d264d97433cb1ce Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 13 May 2024 17:48:32 -0400 Subject: [PATCH 11/20] update --- tests/test_sharded_state_loader.py | 24 +++++++++++++++------- vllm/model_executor/model_loader/loader.py | 16 +++++++++++---- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index b2329eb65b9..a406ee733a4 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -2,6 +2,7 @@ import shutil from tempfile import TemporaryDirectory +import pytest import torch from huggingface_hub import snapshot_download @@ -42,32 +43,41 @@ def test_filter_subtensors(): assert tensor.equal(state_dict[key]) -def test_sharded_state_loader(): +@pytest.mark.parametrize("enable_lora", [True]) +def test_sharded_state_loader(enable_lora): weights_patterns = ("*.bin", "*.pt", "*.safetensors") with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir: input_dir = snapshot_download("facebook/opt-125m", cache_dir=cache_dir) - llm_before = LLM( + llm = LLM( model=input_dir, worker_use_ray=True, - gpu_memory_utilization=0.1, + gpu_memory_utilization=0.3, ) - gen_before = llm_before.generate(prompts, sampling_params) - out_before = [gen.outputs[0].__dict__ for gen in gen_before] # Dump worker states to output directory - model_executor = llm_before.llm_engine.model_executor + model_executor = llm.llm_engine.model_executor model_executor.save_sharded_state(path=output_dir) # Copy metadata files to output directory for file in os.listdir(input_dir): if not any(file.endswith(ext) for ext in weights_patterns): shutil.copy(f"{input_dir}/{file}", output_dir) + llm_before = LLM( + model=input_dir, + worker_use_ray=True, + enable_lora=enable_lora, + gpu_memory_utilization=0.3, + ) + gen_before = llm_before.generate(prompts, sampling_params) + out_before = [gen.outputs[0].__dict__ for gen in gen_before] + llm_after = LLM( model=output_dir, worker_use_ray=True, - gpu_memory_utilization=0.1, + enable_lora=enable_lora, + 
gpu_memory_utilization=0.3, load_format="sharded_state", ) gen_after = llm_after.generate(prompts, sampling_params) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 368748af3b2..39c36579e31 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -434,11 +434,19 @@ def load_model(self, *, model_config: ModelConfig, with safe_open(path, framework="pt") as f: for key in f.keys(): # noqa: SIM118 tensor = f.get_tensor(key) + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. + param_data = state_dict[key].data + param_shape = state_dict[key].shape for dim, size in enumerate(tensor.shape): - state_dict[key].data = ( - state_dict[key].data.narrow(dim, 0, size) - ) - state_dict[key].data.copy_(tensor) + if size < param_shape[dim]: + param_data = param_data.narrow(dim, 0, size) + if tensor.shape != param_shape: + logger.warning("loading tensor of shape %s into " + "parameter '%s' of shape %s", + tensor.shape, key, param_shape) + param_data.copy_(tensor) state_dict.pop(key) if state_dict: raise ValueError( From 2c3dd81dcb13a934d41e4afa59546c95c7da8ca2 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 13 May 2024 17:48:44 -0400 Subject: [PATCH 12/20] fix --- tests/test_sharded_state_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index a406ee733a4..ecca7fee1a2 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -43,7 +43,7 @@ def test_filter_subtensors(): assert tensor.equal(state_dict[key]) -@pytest.mark.parametrize("enable_lora", [True]) +@pytest.mark.parametrize("enable_lora", [False]) def test_sharded_state_loader(enable_lora): weights_patterns = ("*.bin", "*.pt", "*.safetensors") From f9cf33e1f66d3003515801365d00d6e88b0da13a Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 13 May 2024 17:53:09 -0400 Subject: [PATCH 13/20] add exception --- examples/save_sharded_state.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/save_sharded_state.py b/examples/save_sharded_state.py index eafddc61611..09090899174 100644 --- a/examples/save_sharded_state.py +++ b/examples/save_sharded_state.py @@ -47,6 +47,8 @@ def main(args): engine_args = EngineArgs.from_cli_args(args) + if engine_args.enable_lora: + raise ValueError("Saving with enable_lora=True is not supported!") model_path = engine_args.model if not Path(model_path).is_dir(): raise ValueError("model path must be a local directory") From 00168ac3294ee1533b70b42c9374ffe46004e878 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 13 May 2024 18:49:41 -0400 Subject: [PATCH 14/20] copytree --- examples/save_sharded_state.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/save_sharded_state.py b/examples/save_sharded_state.py index 09090899174..e64b48ca6b3 100644 --- a/examples/save_sharded_state.py +++ b/examples/save_sharded_state.py @@ -64,7 +64,11 @@ def main(args): # Copy metadata files to output directory for file in os.listdir(model_path): if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"): - shutil.copy(f"{model_path}/{file}", args.output) + if os.path.isdir(file): + shutil.copytree(f"{model_path}/{file}", + f"{args.output}/{file}") + else: + shutil.copy(f"{model_path}/{file}", args.output) if __name__ == "__main__": From 
f174112e924812fb6cfa5f1afd119da13e5ae124 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 13 May 2024 18:55:10 -0400 Subject: [PATCH 15/20] merge main --- vllm/model_executor/model_loader/loader.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 4d0438a24d4..8864db35201 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -428,14 +428,16 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: from safetensors import safe_open from vllm.distributed import get_tensor_model_parallel_rank with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, - lora_config, vision_language_config) + lora_config, vision_language_config, + cache_config) rank = get_tensor_model_parallel_rank() pattern = os.path.join( model_config.model, From 03ee406f24fab5e944e6d7274fe8c33e88dc3afa Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 13 May 2024 19:25:58 -0400 Subject: [PATCH 16/20] fix --- examples/save_sharded_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/save_sharded_state.py b/examples/save_sharded_state.py index e64b48ca6b3..6591098ec19 100644 --- a/examples/save_sharded_state.py +++ b/examples/save_sharded_state.py @@ -64,7 +64,7 @@ def main(args): # Copy metadata files to output directory for file in os.listdir(model_path): if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"): - if os.path.isdir(file): + if os.path.isdir(f"{model_path}/{file}"): shutil.copytree(f"{model_path}/{file}", f"{args.output}/{file}") else: From 0d06a40307d8a63eb8f278096cad9c6296f0322d Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 13 May 2024 22:58:09 -0400 Subject: [PATCH 17/20] fix test --- tests/test_sharded_state_loader.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index ecca7fee1a2..8540e98da36 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -43,12 +43,13 @@ def test_filter_subtensors(): assert tensor.equal(state_dict[key]) -@pytest.mark.parametrize("enable_lora", [False]) +@pytest.mark.parametrize("enable_lora", [False, True]) def test_sharded_state_loader(enable_lora): weights_patterns = ("*.bin", "*.pt", "*.safetensors") with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir: - input_dir = snapshot_download("facebook/opt-125m", cache_dir=cache_dir) + input_dir = snapshot_download("meta-llama/Llama-2-7b-hf", + cache_dir=cache_dir) llm = LLM( model=input_dir, @@ -63,6 +64,7 @@ def test_sharded_state_loader(enable_lora): for file in os.listdir(input_dir): if not any(file.endswith(ext) for ext in weights_patterns): shutil.copy(f"{input_dir}/{file}", output_dir) + del llm.llm_engine.model_executor llm_before = LLM( model=input_dir, @@ -72,6 +74,7 @@ def test_sharded_state_loader(enable_lora): ) gen_before = llm_before.generate(prompts, sampling_params) out_before = [gen.outputs[0].__dict__ for gen in gen_before] + del llm_before.llm_engine.model_executor llm_after = LLM( model=output_dir, 
@@ -82,5 +85,6 @@ def test_sharded_state_loader(enable_lora): ) gen_after = llm_after.generate(prompts, sampling_params) out_after = [gen.outputs[0].__dict__ for gen in gen_after] + del llm_after.llm_engine.model_executor assert out_before == out_after From b7f17411c485939c703d65bad9075a37f6523eb7 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 13 May 2024 23:12:14 -0400 Subject: [PATCH 18/20] os.path.join --- examples/save_sharded_state.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/save_sharded_state.py b/examples/save_sharded_state.py index 6591098ec19..a3ea52e7457 100644 --- a/examples/save_sharded_state.py +++ b/examples/save_sharded_state.py @@ -64,11 +64,11 @@ def main(args): # Copy metadata files to output directory for file in os.listdir(model_path): if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"): - if os.path.isdir(f"{model_path}/{file}"): - shutil.copytree(f"{model_path}/{file}", - f"{args.output}/{file}") + if os.path.isdir(os.path.join(model_path, file)): + shutil.copytree(os.path.join(model_path, file), + os.path.join(args.output, file)) else: - shutil.copy(f"{model_path}/{file}", args.output) + shutil.copy(os.path.join(model_path, file), args.output) if __name__ == "__main__": From ee6739ffa588306a7b4327ae05fb47e6af137e26 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 16 May 2024 04:11:13 +0000 Subject: [PATCH 19/20] yapf --- examples/save_sharded_state.py | 17 ++++++++--------- vllm/executor/distributed_gpu_executor.py | 3 ++- vllm/model_executor/model_loader/loader.py | 13 ++++++------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/examples/save_sharded_state.py b/examples/save_sharded_state.py index a3ea52e7457..c595d98ba27 100644 --- a/examples/save_sharded_state.py +++ b/examples/save_sharded_state.py @@ -1,11 +1,3 @@ -import argparse -import dataclasses -import os -import shutil -from pathlib import Path - -from vllm import LLM, EngineArgs - """ Saves each worker's model state dict directly to a checkpoint, which enables a fast load path for large tensor-parallel models where each worker only needs to @@ -28,6 +20,13 @@ tensor_parallel_size=8, ) """ +import argparse +import dataclasses +import os +import shutil +from pathlib import Path + +from vllm import LLM, EngineArgs parser = argparse.ArgumentParser() EngineArgs.add_cli_args(parser) @@ -41,7 +40,7 @@ help="string pattern of saved filenames") parser.add_argument("--max-file-size", type=str, - default=5 * 1024 ** 3, + default=5 * 1024**3, help="max size (in bytes) of each safetensors file") diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 86f783cf6af..c5b1e61112a 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -77,7 +77,8 @@ def remove_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self._run_workers("list_loras") - def save_sharded_state(self, + def save_sharded_state( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 8864db35201..e19c4d256ec 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -389,8 +389,7 @@ def __init__(self, load_config: LoadConfig): @staticmethod def _filter_subtensors( - tensors: Dict[str, torch.Tensor], - ) -> Dict[str, torch.Tensor]: + tensors: Dict[str, torch.Tensor]) -> 
Dict[str, torch.Tensor]: """ Filter out all tensors that share the same memory or a subset of the memory of another tensor. @@ -448,8 +447,7 @@ def load_model(self, *, model_config: ModelConfig, # TODO: support un-sharded checkpoints too raise ValueError( f"Could not find checkpoint files '{pattern}', only " - f"pre-sharded checkpoints are currently supported!" - ) + f"pre-sharded checkpoints are currently supported!") state_dict = self._filter_subtensors(model.state_dict()) for path in filepaths: with safe_open(path, framework="pt") as f: @@ -464,9 +462,10 @@ def load_model(self, *, model_config: ModelConfig, if size < param_shape[dim]: param_data = param_data.narrow(dim, 0, size) if tensor.shape != param_shape: - logger.warning("loading tensor of shape %s into " - "parameter '%s' of shape %s", - tensor.shape, key, param_shape) + logger.warning( + "loading tensor of shape %s into " + "parameter '%s' of shape %s", tensor.shape, + key, param_shape) param_data.copy_(tensor) state_dict.pop(key) if state_dict: From 1aafeaf857732b44f50a6b4ef8e9529a8e6d6959 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 16 May 2024 04:11:33 +0000 Subject: [PATCH 20/20] safetensors -> safetensors.torch --- vllm/model_executor/model_loader/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index e19c4d256ec..dc568928b28 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -429,7 +429,7 @@ def load_model(self, *, model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: - from safetensors import safe_open + from safetensors.torch import safe_open from vllm.distributed import get_tensor_model_parallel_rank with set_default_torch_dtype(model_config.dtype):
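
For reference, a minimal end-to-end sketch of the workflow these patches enable, pieced together from the docstring in examples/save_sharded_state.py and from tests/test_sharded_state_loader.py above. The paths, the tensor-parallel size, and the sampling settings are illustrative assumptions; the save_sharded_state call, the "sharded_state" load format, and the metadata-copy step come directly from the patches, and the sketch assumes an executor that implements save_sharded_state (the distributed GPU executor patched here) running on a machine with at least two GPUs.

import os
import shutil

from vllm import LLM, SamplingParams

input_dir = "/path/to/load"   # illustrative: a local HF-style checkpoint directory
output_dir = "/path/to/save"  # illustrative: destination for the sharded checkpoint

# 1) Save: each tensor-parallel rank dumps its own shard as
#    model-rank-{rank}-part-{part}.safetensors, split into files of at most
#    max_size bytes.
llm = LLM(model=input_dir, tensor_parallel_size=2, worker_use_ray=True)
os.makedirs(output_dir, exist_ok=True)
llm.llm_engine.model_executor.save_sharded_state(path=output_dir,
                                                 max_size=5 * 1024**3)

# Copy config/tokenizer files next to the shards, as the example script does.
for name in os.listdir(input_dir):
    if os.path.splitext(name)[1] not in (".bin", ".pt", ".safetensors"):
        src = os.path.join(input_dir, name)
        if os.path.isfile(src):
            shutil.copy(src, output_dir)

# Free the first engine before constructing a second one in the same process,
# as the test does by deleting the model executor.
del llm

# 2) Load: with load_format="sharded_state", each rank reads only its own
#    shard instead of scanning the full original checkpoint.
llm = LLM(model=output_dir,
          load_format="sharded_state",
          tensor_parallel_size=2,
          worker_use_ray=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)

Note that a checkpoint saved this way is expected to be loaded back with the same tensor-parallel size, since each model-rank-{rank}-part-{part}.safetensors file holds the shard for one specific rank.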